"""Build a local SQLite database (./e621.db) from e621's db_export CSV dumps.

Usage: python <this script> <date>

The tags and posts exports for <date> are downloaded into the current
directory and decompressed next to themselves. Both steps are skipped when
their output file already exists. The CSVs are then loaded into SQLite; each
table's import is skipped when its row count already matches the CSV.
"""

import asyncio
import csv
import gzip
import logging
import shutil
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import aiohttp
import aiosqlite

log = logging.getLogger(__name__)

# Bytes requested per read while streaming a download.
DOWNLOAD_CHUNK_SIZE = 8192
# Rows sent to SQLite per executemany() call; one call per row costs a
# thread round-trip in aiosqlite, which dominates import time.
INSERT_BATCH_SIZE = 10_000

SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS db_version (
    version int primary key
) STRICT;

CREATE TABLE IF NOT EXISTS tags (
    id int primary key,
    name text unique,
    category int not null,
    post_count int not null
) STRICT;

CREATE TABLE IF NOT EXISTS posts (
    id int primary key,
    uploader_id int not null,
    created_at int not null,
    md5 text unique not null,
    source text not null,
    rating text not null,
    tag_string text not null,
    is_deleted int not null,
    is_pending int not null,
    is_flagged int not null,
    score int not null,
    up_score int not null,
    down_score int not null,
    is_rating_locked int not null
) STRICT;
"""

# "or replace" so that loading a newer dump over an older database updates
# existing rows instead of failing on the primary-key / unique constraints.
INSERT_TAG_SQL = (
    "insert or replace into tags (id, name, category, post_count) "
    "values (?, ?, ?, ?)"
)

# Column order of the posts table; also the order of Post's fields.
POST_COLUMNS = (
    "id",
    "uploader_id",
    "created_at",
    "md5",
    "source",
    "rating",
    "tag_string",
    "is_deleted",
    "is_pending",
    "is_flagged",
    "score",
    "up_score",
    "down_score",
    "is_rating_locked",
)
INSERT_POST_SQL = (
    f"insert or replace into posts ({', '.join(POST_COLUMNS)}) "
    f"values ({', '.join('?' for _ in POST_COLUMNS)})"
)


@dataclass
class Context:
    """Resources shared by one build run."""

    session: aiohttp.ClientSession
    db: Any  # the aiosqlite connection opened in main()


@dataclass
class Tag:
    """One row of the tags export / tags table."""

    id: int
    name: str
    category: int
    post_count: int


@dataclass
class Post:
    """One row of the posts export / posts table (field order == POST_COLUMNS)."""

    id: int
    uploader_id: int
    created_at: int  # Unix seconds, see _parse_created_at
    md5: str
    source: str
    rating: str
    tag_string: str
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int


def e621_bool(text: str) -> bool:
    """Parse a boolean cell of the CSV export: "t" is true, anything else false."""
    return text == "t"


def _percent(done: int, total: int) -> float:
    """Return done/total as a percentage rounded to 2 decimals (100 if total <= 0)."""
    if total <= 0:
        return 100.0
    return round(done / total * 100, 2)


def _partial_path(path: Path) -> Path:
    """Return the scratch file name used while *path* is still being written."""
    return path.with_name(path.name + ".part")


async def _download(session: aiohttp.ClientSession, url: str, output_path: Path) -> None:
    """Stream *url* into *output_path*, creating it only once fully downloaded.

    Raises:
        RuntimeError: the server answered with a status other than 200.
    """
    log.info("downloading %r into %s", url, output_path)
    partial_path = _partial_path(output_path)
    async with session.get(url) as resp:
        if resp.status != 200:
            raise RuntimeError(f"unexpected HTTP status {resp.status} for {url!r}")
        length_header = resp.headers.get("content-length")
        total_length = int(length_header) if length_header is not None else None
        if total_length is None:
            log.info("server did not report a size, no progress will be shown")
        else:
            log.info("to download %d bytes", total_length)

        downloaded_bytes = 0
        download_ratio = 0
        with partial_path.open(mode="wb") as part_fd:
            async for chunk in resp.content.iter_chunked(DOWNLOAD_CHUNK_SIZE):
                part_fd.write(chunk)
                downloaded_bytes += len(chunk)
                if total_length:
                    new_ratio = round(downloaded_bytes / total_length * 100)
                    if new_ratio != download_ratio:
                        download_ratio = new_ratio
                        log.info("download at %d%%", download_ratio)
    # Publish atomically: an interrupted run leaves only the .part file, so the
    # exists() check in main_with_ctx retries instead of trusting a truncated file.
    partial_path.replace(output_path)


def _decompress(input_path: Path, output_path: Path) -> None:
    """Gunzip *input_path* into *output_path*, creating it only once complete."""
    log.info("decompressing %s into %s", input_path.name, output_path.name)
    partial_path = _partial_path(output_path)
    with gzip.open(input_path, "rb") as in_fd, partial_path.open(mode="wb") as out_fd:
        shutil.copyfileobj(in_fd, out_fd)
    partial_path.replace(output_path)


def _count_csv_records(csv_fd) -> int:
    """Count data records (header excluded) in an open CSV file and rewind it.

    Records are counted with csv.reader, so quoted fields containing newlines
    count once.
    """
    csv_fd.seek(0)
    record_count = sum(1 for _row in csv.reader(csv_fd))
    csv_fd.seek(0)
    return max(record_count - 1, 0)


def _parse_created_at(text: str) -> int:
    """Convert "YYYY-MM-DD HH:MM:SS[.fraction]" to whole Unix seconds.

    The fractional part is optional and discarded.
    """
    whole_seconds, _sep, _fraction = text.partition(".")
    created_at = datetime.strptime(whole_seconds, "%Y-%m-%d %H:%M:%S")
    # NOTE(review): the datetime is naive, so timestamp() interprets it in the
    # machine's local timezone (unchanged from the original). If the export is
    # UTC, attach tzinfo=timezone.utc here -- confirm before changing.
    return int(created_at.timestamp())


def _tag_from_row(row: list) -> Tag:
    """Build a Tag from one csv.reader row (id, name, category, post_count)."""
    return Tag(int(row[0]), row[1], int(row[2]), int(row[3]))


def _post_from_row(row: dict) -> Post:
    """Build a Post from one csv.DictReader row of the posts export."""
    return Post(
        id=int(row["id"]),
        uploader_id=int(row["uploader_id"]),
        created_at=_parse_created_at(row["created_at"]),
        md5=row["md5"],
        source=row["source"],
        rating=row["rating"],
        tag_string=row["tag_string"],
        is_deleted=e621_bool(row["is_deleted"]),
        is_pending=e621_bool(row["is_pending"]),
        is_flagged=e621_bool(row["is_flagged"]),
        score=int(row["score"]),
        up_score=int(row["up_score"]),
        down_score=int(row["down_score"]),
        is_rating_locked=e621_bool(row["is_rating_locked"]),
    )


async def _insert_in_batches(db, sql: str, rows, total: int, label: str) -> None:
    """executemany() *rows* into *db* in INSERT_BATCH_SIZE chunks, logging progress.

    Does not commit; the caller does.
    """
    processed = 0
    batch = []
    for row in rows:
        batch.append(row)
        if len(batch) >= INSERT_BATCH_SIZE:
            await db.executemany(sql, batch)
            processed += len(batch)
            batch = []
            log.info("%s processed at %.2f%%", label, _percent(processed, total))
    if batch:
        await db.executemany(sql, batch)
        processed += len(batch)
    log.info("%s processed at %.2f%%", label, _percent(processed, total))


async def _import_tags(db, csv_path: Path) -> None:
    """Load the tags CSV into the tags table unless the row counts already match.

    Raises:
        ValueError: the CSV header does not have exactly 4 columns.
    """
    tag_count_rows = await db.execute_fetchall("select count(*) from tags")
    tag_count = tag_count_rows[0][0]
    log.info("already have %d tags", tag_count)

    # newline="" is required by the csv module so quoted newlines survive.
    with csv_path.open(mode="r", encoding="utf-8", newline="") as tags_csv_fd:
        record_count = _count_csv_records(tags_csv_fd)
        log.info("%d tags to import", record_count)
        if record_count == tag_count:
            log.info("same counts, not going to reimport")
        else:
            tags_reader = csv.reader(tags_csv_fd)
            header = next(tags_reader, None)
            if header is None or len(header) != 4:
                raise ValueError(f"unexpected tags CSV header: {header!r}")
            tag_rows = (
                (tag.id, tag.name, tag.category, tag.post_count)
                for tag in map(_tag_from_row, tags_reader)
            )
            await _insert_in_batches(db, INSERT_TAG_SQL, tag_rows, record_count, "tags")
    log.info("tags done")
    await db.commit()


async def _import_posts(db, csv_path: Path) -> None:
    """Load the posts CSV into the posts table unless the row counts already match."""
    post_count_rows = await db.execute_fetchall("select count(*) from posts")
    post_count = post_count_rows[0][0]
    log.info("already have %d posts", post_count)

    with csv_path.open(mode="r", encoding="utf-8", newline="") as posts_csv_fd:
        record_count = _count_csv_records(posts_csv_fd)
        log.info("%d posts to import", record_count)
        if record_count == post_count:
            log.info("already imported everything, skipping")
        else:
            post_rows = (
                tuple(getattr(post, column) for column in POST_COLUMNS)
                for post in map(_post_from_row, csv.DictReader(posts_csv_fd))
            )
            await _insert_in_batches(db, INSERT_POST_SQL, post_rows, record_count, "posts")
    log.info("posts done")
    await db.commit()


async def main_with_ctx(ctx, wanted_date):
    """Download the tags/posts exports for *wanted_date* and load them into ctx.db.

    Files are written to the current working directory. Already-present
    downloads and decompressed files are reused.
    """
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }

    output_compressed_paths = {}
    output_uncompressed_paths = {}
    for url_type, url in urls.items():
        compressed_path = Path.cwd() / Path(urlparse(url).path).name
        output_compressed_paths[url_type] = compressed_path
        # "<name>.csv.gz" -> "<name>.csv"
        output_uncompressed_paths[url_type] = compressed_path.with_suffix("")

    for url_type, url in urls.items():
        output_path = output_compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue
        await _download(ctx.session, url, output_path)

    for url_type in urls:
        output_path = output_uncompressed_paths[url_type]
        if output_path.exists():
            log.info("decompressed file %s already exists, ignoring", output_path)
            continue
        _decompress(output_compressed_paths[url_type], output_path)

    # now that everything is downloaded, compile the db
    await ctx.db.executescript(SCHEMA_SQL)
    await ctx.db.commit()

    await _import_tags(ctx.db, output_uncompressed_paths["tags"])

    log.info("going to process posts")
    await _import_posts(ctx.db, output_uncompressed_paths["posts"])

    log.info("vacuuming db...")
    await ctx.db.execute("vacuum")
    log.info("database built")


async def main():
    """Entry point: build ./e621.db for the export date given as argv[1]."""
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <export date>")
    wanted_date = sys.argv[1]
    async with aiohttp.ClientSession() as session:
        async with aiosqlite.connect("./e621.db") as db:
            ctx = Context(session, db)
            await main_with_ctx(ctx, wanted_date)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Some export fields exceed csv's default per-field size limit.
    csv.field_size_limit(2**19)
    asyncio.run(main())