e621_api_cloner/build_database.py

179 lines
5.3 KiB
Python

import asyncio
import csv
import gzip
import logging
import shutil
import tempfile
import sys
from dataclasses import dataclass
from urllib.parse import urlparse
from pathlib import Path
from typing import Any
import aiosqlite
import aiohttp
# Module-level logger named after this module; every progress message in this
# file goes through it.
log = logging.getLogger(__name__)
@dataclass
class Context:
    """Bundle of the long-lived handles shared by the build steps."""

    # HTTP client session; used to download the export files.
    session: aiohttp.ClientSession
    # Open database connection; main() passes an aiosqlite connection here.
    db: Any
@dataclass
class Tag:
    """One row of the tags CSV export, with numeric columns already parsed."""

    # Primary key of the tag (first CSV column).
    id: int
    # Tag name (second CSV column); stored in a unique column.
    name: str
    # Numeric category code (third CSV column).
    category: int
    # Post count for the tag (fourth CSV column).
    post_count: int
def _partial_path(final_path: Path) -> Path:
    """Return the sibling ``<name>.part`` path used while *final_path* is written."""
    return final_path.with_name(final_path.name + ".part")


def _discard(path: Path) -> None:
    """Delete *path* if it is still present."""
    if path.exists():
        path.unlink()


def _percent(done: int, total: int) -> int:
    """Return ``done / total`` as a rounded percentage, or 0 when total is unknown."""
    if total <= 0:
        return 0
    return round((done / total) * 100)


async def _download_export(session: Any, url: str, output_path: Path) -> None:
    """Stream *url* to *output_path*, logging progress once per percentage point."""
    log.info("downloading %r into %s", url, output_path)
    partial_path = _partial_path(output_path)
    try:
        async with session.get(url) as resp:
            # Explicit check instead of `assert`, which is stripped under -O.
            if resp.status != 200:
                raise RuntimeError(
                    f"unexpected HTTP status {resp.status} while fetching {url}"
                )

            # The header can be missing (e.g. chunked transfer); treat that as
            # "size unknown" rather than crashing or dividing by zero.
            raw_length = resp.headers.get("content-length")
            total_length = int(raw_length) if raw_length is not None else 0
            if total_length > 0:
                log.info("to download %d bytes", total_length)
            else:
                log.info("download size unknown, progress will not be reported")

            downloaded_bytes = 0
            download_ratio = 0
            with partial_path.open(mode="wb") as partial_fd:
                async for chunk in resp.content.iter_chunked(8192):
                    partial_fd.write(chunk)
                    downloaded_bytes += len(chunk)
                    new_download_ratio = _percent(downloaded_bytes, total_length)
                    if new_download_ratio != download_ratio:
                        # Update first so the message shows the current value.
                        download_ratio = new_download_ratio
                        log.info("download at %d%%", download_ratio)

        # Only a fully written file gets the final name, so the "already
        # exists" check on a later run never sees a truncated download.
        partial_path.replace(output_path)
    finally:
        _discard(partial_path)


def _decompress_export(input_path: Path, output_path: Path) -> None:
    """Gunzip *input_path* into *output_path* via a temporary ``.part`` file."""
    log.info("decompressing %s into %s", input_path.name, output_path.name)
    partial_path = _partial_path(output_path)
    try:
        with gzip.open(input_path, "rb") as in_fd:
            with partial_path.open(mode="wb") as out_fd:
                shutil.copyfileobj(in_fd, out_fd)
        partial_path.replace(output_path)
    finally:
        _discard(partial_path)


async def _create_schema(db: Any) -> None:
    """Create the db_version, tags and posts tables if they do not exist yet."""
    await db.executescript(
        """
        CREATE TABLE IF NOT EXISTS db_version (
            version int primary key
        ) STRICT;
        CREATE TABLE IF NOT EXISTS tags (
            id int primary key,
            name text unique,
            category int not null,
            post_count int not null
        ) STRICT;
        CREATE TABLE IF NOT EXISTS posts (
            id int primary key,
            uploader_id int not null,
            created_at int not null,
            md5 text unique not null,
            source text not null,
            rating text not null,
            tag_string text not null,
            is_deleted int not null,
            is_pending int not null,
            is_flagged int not null,
            score int not null,
            up_score int not null,
            down_score int not null,
            is_rating_locked int not null
        ) STRICT;
        """
    )
    await db.commit()


async def _import_tags(db: Any, tags_csv_path: Path) -> None:
    """Insert every row of the tags CSV at *tags_csv_path* into the tags table."""
    # newline="" is what the csv module asks for so that quoted fields keep
    # their embedded line breaks intact.
    with tags_csv_path.open(mode="r", encoding="utf-8", newline="") as tags_csv_fd:
        # First pass: count physical lines (minus the header) for progress only.
        line_count = max(sum(1 for _line in tags_csv_fd) - 1, 0)
        log.info("%d tags to import", line_count)

        tags_csv_fd.seek(0)
        tags_reader = csv.reader(tags_csv_fd)
        header = next(tags_reader, None)
        if header is None or len(header) != 4:
            raise ValueError(
                f"{tags_csv_path} does not start with a 4-column header: {header!r}"
            )

        processed_count = 0
        processed_ratio = 0
        for row in tags_reader:
            if not row:
                # csv.reader yields [] for blank lines; nothing to import.
                continue
            tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
            await db.execute(
                "insert into tags (id, name, category, post_count) values (?, ?, ?, ?)",
                (tag.id, tag.name, tag.category, tag.post_count),
            )
            processed_count += 1
            new_processed_ratio = _percent(processed_count, line_count)
            if new_processed_ratio != processed_ratio:
                # Update first so the message shows the current value.
                processed_ratio = new_processed_ratio
                log.info("processed at %d%%", processed_ratio)


async def main_with_ctx(ctx: "Context", wanted_date: str) -> None:
    """Download the export for *wanted_date* and load it into the database.

    Steps, in order: download each export archive into the current working
    directory (skipped when the file already exists), decompress it next to
    the archive (also skipped when already present), create the tables, and
    insert every row of the tags CSV.

    Args:
        ctx: Object exposing ``session`` (used for ``session.get(url)``) and
            ``db`` (used for ``executescript``/``execute``/``commit``).
        wanted_date: Text substituted into the export file names, i.e.
            ``tags-{wanted_date}.csv.gz``.

    Raises:
        RuntimeError: The server answered with a status other than 200.
        ValueError: The tags CSV is empty or its header is not 4 columns wide.

    Note:
        Rows are added with a plain ``insert``, so importing into a database
        that already holds the same tag ids violates the primary key.
    """
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        # "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }

    compressed_paths = {}
    uncompressed_paths = {}
    for url_type, url in urls.items():
        compressed = Path.cwd() / Path(urlparse(url).path).name
        compressed_paths[url_type] = compressed
        # Dropping only the final ".gz" suffix works no matter how many dots
        # the rest of the file name contains.
        uncompressed_paths[url_type] = compressed.with_suffix("")

    for url_type, url in urls.items():
        output_path = compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue
        await _download_export(ctx.session, url, output_path)

    # decompress
    for url_type in urls:
        output_path = uncompressed_paths[url_type]
        if output_path.exists():
            log.info("decompressed file %s already exists, ignoring", output_path)
            continue
        _decompress_export(compressed_paths[url_type], output_path)

    # now that everything is downloaded, compile the db
    await _create_schema(ctx.db)
    await _import_tags(ctx.db, uncompressed_paths["tags"])
    await ctx.db.commit()
    log.info("done")
async def main() -> None:
    """Build ``./e621.db`` from the export named by the first CLI argument.

    Opens one HTTP session and one database connection, wraps them in a
    Context, and hands over to main_with_ctx. Both are closed on exit.

    Raises:
        SystemExit: No date argument was given on the command line.
    """
    # Fail with a readable usage line instead of a bare IndexError.
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <wanted_date>")
    wanted_date = sys.argv[1]

    async with aiohttp.ClientSession() as session:
        async with aiosqlite.connect("./e621.db") as db:
            await main_with_ctx(Context(session, db), wanted_date)
if __name__ == "__main__":
    # INFO level so the download and import progress messages are emitted.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())