import asyncio
import csv
import gzip
import logging
import shutil
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import aiohttp
import aiosqlite

log = logging.getLogger(__name__)


@dataclass
class Context:
    """Shared resources passed to every import step."""

    session: aiohttp.ClientSession  # HTTP session used for the export downloads
    db: Any  # aiosqlite connection the exports are loaded into


@dataclass
class Tag:
    """One row of the e621 tags export, in ``tags`` table column order."""

    id: int
    name: str
    category: int
    post_count: int


@dataclass
class Post:
    """One row of the e621 posts export, in ``posts`` table column order.

    ``created_at`` holds a Unix timestamp; the ``is_*`` flags hold the values
    returned by :func:`e621_bool`.
    """

    id: int
    uploader_id: int
    created_at: int
    md5: str
    source: str
    rating: str
    tag_string: str
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int


def e621_bool(text: str) -> bool:
    """Decode a boolean CSV cell: exactly ``"t"`` is True, anything else False."""
    return text == "t"


def _log_progress(
    label: str, done: int, total: int, last_ratio: float, ndigits: int = 0
) -> float:
    """Log ``"<label> at N%"`` whenever the rounded percentage changes.

    Returns the current rounded percentage, to be passed back as
    ``last_ratio`` on the next call. When ``total`` is not positive (e.g. an
    unknown download size) nothing is logged and ``last_ratio`` is returned.
    """
    if total <= 0:
        return last_ratio
    ratio = round(done / total * 100, ndigits)
    if ratio != last_ratio:
        log.info("%s at %.*f%%", label, ndigits, ratio)
    return ratio


def _count_csv_rows(csv_fd) -> int:
    """Return the number of CSV records after the header line in ``csv_fd``.

    Counts parsed records (not physical lines), so quoted fields spanning
    several lines are counted once. Consumes the file; callers rewind.
    """
    return sum(1 for _row in csv.reader(csv_fd)) - 1


def _open_export(path: Path):
    """Open a gzipped CSV export as a text stream for the ``csv`` module."""
    # newline="" is required by the csv module so that quoted fields containing
    # line breaks are parsed intact.
    # NOTE(review): the export is assumed to be UTF-8 -- confirm.
    return gzip.open(path, "rt", encoding="utf-8", newline="")


async def _download(ctx, url: str, output_path: Path) -> None:
    """Stream ``url`` to ``output_path`` without ever leaving a partial file there.

    The body is written to ``<output_path>.part`` and renamed into place only
    once it has been fully received, so an interrupted run is retried next
    time instead of being mistaken for a finished download.

    Raises:
        aiohttp.ClientResponseError: if the server answers with an error status.
        RuntimeError: if fewer/more bytes than Content-Length were received.
    """
    log.info("downloading %r into %s", url, output_path)
    part_path = output_path.with_name(output_path.name + ".part")
    async with ctx.session.get(url) as resp:
        resp.raise_for_status()
        # None when the server does not send a Content-Length header.
        total_length = resp.content_length
        log.info("to download %s bytes", total_length)
        downloaded_bytes = 0
        download_ratio = 0
        try:
            with part_path.open(mode="wb") as part_fd:
                async for chunk in resp.content.iter_chunked(8192):
                    part_fd.write(chunk)
                    downloaded_bytes += len(chunk)
                    download_ratio = _log_progress(
                        "download", downloaded_bytes, total_length or 0, download_ratio
                    )
            if total_length is not None and downloaded_bytes != total_length:
                raise RuntimeError(
                    f"incomplete download of {url}: "
                    f"got {downloaded_bytes} of {total_length} bytes"
                )
            part_path.replace(output_path)
        except BaseException:
            # Also covers cancellation; never leave a stale .part behind.
            part_path.unlink(missing_ok=True)
            raise


async def main_with_ctx(ctx, wanted_date):
    """Download the tag/post exports for ``wanted_date`` and load them into ``ctx.db``.

    ``wanted_date`` is interpolated verbatim into the export file names
    (``tags-<date>.csv.gz`` / ``posts-<date>.csv.gz``). The compressed files
    are kept in the current working directory and reused if already present.
    The database is vacuumed if either import pass actually ran.
    """
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }
    compressed_paths = {
        url_type: Path.cwd() / Path(urlparse(url).path).name
        for url_type, url in urls.items()
    }

    for url_type, url in urls.items():
        output_path = compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue
        await _download(ctx, url, output_path)

    # now that everything's downloaded, compile the db
    await ctx.db.executescript(
        """
        CREATE TABLE IF NOT EXISTS db_version (
            version int primary key
        ) STRICT;

        CREATE TABLE IF NOT EXISTS tags (
            id int primary key,
            name text unique,
            category int not null,
            post_count int not null
        ) STRICT;

        CREATE TABLE IF NOT EXISTS posts (
            id int primary key,
            uploader_id int not null,
            created_at int not null,
            md5 text unique not null,
            source text not null,
            rating text not null,
            tag_string text not null,
            is_deleted int not null,
            is_pending int not null,
            is_flagged int not null,
            score int not null,
            up_score int not null,
            down_score int not null,
            is_rating_locked int not null
        ) STRICT;
        """
    )
    await ctx.db.commit()

    with _open_export(compressed_paths["tags"]) as tags_csv_fd:
        tags_imported = await work_tags(ctx, tags_csv_fd)

    log.info("going to process posts")
    with _open_export(compressed_paths["posts"]) as posts_csv_fd:
        posts_imported = await work_posts(ctx, posts_csv_fd)

    if tags_imported or posts_imported:
        log.info("vacuuming db...")
        await ctx.db.execute("vacuum")
    log.info("database built")


async def work_tags(ctx, tags_csv_fd):
    """Import the tags CSV (header + id,name,category,post_count rows) into ``tags``.

    The import is skipped when the CSV's data-row count equals the number of
    rows already in ``tags``; otherwise every row is inserted with
    ``on conflict do nothing`` and committed at the end.

    Returns:
        True if an import pass ran, False if it was skipped.

    Raises:
        ValueError: if the header row is missing or does not have 4 columns.
    """
    tag_count_rows = await ctx.db.execute_fetchall("select count(*) from tags")
    tag_count = tag_count_rows[0][0]
    log.info("already have %d tags", tag_count)

    line_count = _count_csv_rows(tags_csv_fd)
    log.info("%d tags to import", line_count)
    if line_count == tag_count:
        log.info("same counts, not going to reimport")
        return False

    tags_csv_fd.seek(0)
    tags_reader = csv.reader(tags_csv_fd)
    header = next(tags_reader, None)
    if header is None or len(header) != 4:
        raise ValueError(f"unexpected tags CSV header: {header!r}")

    processed_ratio = 0
    for processed_count, row in enumerate(tags_reader, start=1):
        tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
        await ctx.db.execute(
            "insert into tags (id, name, category, post_count) "
            "values (?, ?, ?, ?) on conflict do nothing",
            (tag.id, tag.name, tag.category, tag.post_count),
        )
        processed_ratio = _log_progress(
            "tags processed", processed_count, line_count, processed_ratio
        )

    log.info("tags commit...")
    await ctx.db.commit()
    log.info("tags done...")
    return True


def _post_from_row(row) -> Post:
    """Build a :class:`Post` from one ``csv.DictReader`` row of the posts export."""
    # created_at is "%Y-%m-%d %H:%M:%S", optionally followed by "." and a
    # fractional part. partition() keeps the whole string when there is no
    # ".", whereas slicing up to str.find(".") would chop the last character.
    created_at_str = row["created_at"].partition(".")[0]
    created_at = datetime.strptime(created_at_str, "%Y-%m-%d %H:%M:%S")
    # NOTE(review): created_at is a naive datetime, so timestamp() interprets it
    # in the machine's local timezone -- confirm whether the export is UTC.
    return Post(
        id=int(row["id"]),
        uploader_id=int(row["uploader_id"]),
        created_at=int(created_at.timestamp()),
        md5=row["md5"],
        source=row["source"],
        rating=row["rating"],
        tag_string=row["tag_string"],
        is_deleted=e621_bool(row["is_deleted"]),
        is_pending=e621_bool(row["is_pending"]),
        is_flagged=e621_bool(row["is_flagged"]),
        score=int(row["score"]),
        up_score=int(row["up_score"]),
        down_score=int(row["down_score"]),
        is_rating_locked=e621_bool(row["is_rating_locked"]),
    )


async def work_posts(ctx, posts_csv_fd):
    """Import the posts CSV (with a header row naming the columns) into ``posts``.

    The import is skipped when the CSV's data-row count equals the number of
    rows already in ``posts``; otherwise every row is inserted with
    ``on conflict do nothing`` and committed at the end.

    Returns:
        True if an import pass ran, False if it was skipped.
    """
    post_count_rows = await ctx.db.execute_fetchall("select count(*) from posts")
    post_count = post_count_rows[0][0]
    log.info("already have %d posts", post_count)

    line_count = _count_csv_rows(posts_csv_fd)
    log.info("%d posts to import", line_count)
    if line_count == post_count:
        log.info("already imported everything, skipping")
        return False

    posts_csv_fd.seek(0)
    posts_reader = csv.DictReader(posts_csv_fd)
    processed_ratio = 0.0
    for processed_count, row in enumerate(posts_reader, start=1):
        post = _post_from_row(row)
        await ctx.db.execute(
            """
            insert into posts (
                id,
                uploader_id,
                created_at,
                md5,
                source,
                rating,
                tag_string,
                is_deleted,
                is_pending,
                is_flagged,
                score,
                up_score,
                down_score,
                is_rating_locked
            ) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
            on conflict do nothing
            """,
            (
                post.id,
                post.uploader_id,
                post.created_at,
                post.md5,
                post.source,
                post.rating,
                post.tag_string,
                post.is_deleted,
                post.is_pending,
                post.is_flagged,
                post.score,
                post.up_score,
                post.down_score,
                post.is_rating_locked,
            ),
        )
        processed_ratio = _log_progress(
            "posts processed", processed_count, line_count, processed_ratio, ndigits=2
        )

    log.info("posts commit...")
    await ctx.db.commit()
    log.info("posts done")
    return True


async def main():
    """Command-line entry point: build ./e621.db from the export named by argv[1]."""
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <export-date>")
    wanted_date = sys.argv[1]
    async with aiohttp.ClientSession() as session:
        async with aiosqlite.connect("./e621.db") as db:
            ctx = Context(session, db)
            await main_with_ctx(ctx, wanted_date)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Raise the csv module's per-field size limit; some export fields can be
    # longer than the default allows.
    csv.field_size_limit(2**19)
    asyncio.run(main())