diff --git a/build_database.py b/build_database.py
index c724fe6..67415c7 100644
--- a/build_database.py
+++ b/build_database.py
@@ -1,4 +1,5 @@
 import asyncio
+import csv
 import gzip
 import logging
 import shutil
@@ -7,7 +8,9 @@ import sys
 from dataclasses import dataclass
 from urllib.parse import urlparse
 from pathlib import Path
+from typing import Any
 
+import aiosqlite
 import aiohttp
 
 log = logging.getLogger(__name__)
@@ -16,71 +19,158 @@ log = logging.getLogger(__name__)
 @dataclass
 class Context:
     session: aiohttp.ClientSession
+    db: Any
+
+
+@dataclass
+class Tag:
+    id: int
+    name: str
+    category: int
+    post_count: int
+
+
+async def main_with_ctx(ctx, wanted_date):
+    urls = {
+        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
+        # "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
+    }
+
+    output_compressed_paths = {}
+    output_uncompressed_paths = {}
+
+    for url_type, url in urls.items():
+        parsed = urlparse(url)
+        parsed_path = Path(parsed.path)
+        output_path = Path.cwd() / parsed_path.name
+        output_compressed_paths[url_type] = output_path
+
+        original_name, original_extension, _gz = parsed_path.name.split(".")
+        output_uncompressed = Path.cwd() / f"{original_name}.{original_extension}"
+        output_uncompressed_paths[url_type] = output_uncompressed
+
+    for url_type, url in urls.items():
+        output_path = output_compressed_paths[url_type]
+        if output_path.exists():
+            log.info("file %s already exists, ignoring", output_path)
+            continue
+
+        log.info("downloading %r into %s", url, output_path)
+        async with ctx.session.get(url) as resp:
+            assert resp.status == 200
+
+            total_length = int(resp.headers["content-length"])
+            downloaded_bytes = 0
+            download_ratio = 0
+
+            log.info("to download %d bytes", total_length)
+
+            with tempfile.TemporaryFile() as temp_fd:
+                async for chunk in resp.content.iter_chunked(8192):
+                    temp_fd.write(chunk)
+                    downloaded_bytes += len(chunk)
+                    new_download_ratio = round((downloaded_bytes / total_length) * 100)
+                    if new_download_ratio != download_ratio:
+                        log.info("download at %d%%", new_download_ratio)
+                        download_ratio = new_download_ratio
+
+                temp_fd.seek(0)
+
+                # write to output
+                log.info("copying temp to output")
+                with output_path.open(mode="wb") as output_fd:
+                    shutil.copyfileobj(temp_fd, output_fd)
+
+    # decompress
+    for url_type, _url in urls.items():
+        input_path = output_compressed_paths[url_type]
+        output_path = output_uncompressed_paths[url_type]
+
+        if output_path.exists():
+            log.info("decompressed file %s already exists, ignoring", output_path)
+            continue
+
+        log.info("decompressing %s into %s", input_path.name, output_path.name)
+        with gzip.open(input_path, "rb") as in_fd:
+            with output_path.open(mode="wb") as out_fd:
+                shutil.copyfileobj(in_fd, out_fd)
+
+    # now that everything's downloaded, compile the db
+    await ctx.db.executescript(
+        """
+    CREATE TABLE IF NOT EXISTS db_version (
+        version int primary key
+    ) STRICT;
+
+    CREATE TABLE IF NOT EXISTS tags (
+        id int primary key,
+        name text unique,
+        category int not null,
+        post_count int not null
+    ) STRICT;
+
+    CREATE TABLE IF NOT EXISTS posts (
+        id int primary key,
+        uploader_id int not null,
+        created_at int not null,
+        md5 text unique not null,
+        source text not null,
+        rating text not null,
+        tag_string text not null,
+        is_deleted int not null,
+        is_pending int not null,
+        is_flagged int not null,
+        score int not null,
+        up_score int not null,
+        down_score int not null,
+        is_rating_locked int not null
+    ) STRICT;
+    """
+    )
+    await ctx.db.commit()
+
+    with output_uncompressed_paths["tags"].open(
+        mode="r", encoding="utf-8"
+    ) as tags_csv_fd:
+        line_count = 0
+        for _line in tags_csv_fd:
+            line_count += 1
+
+        line_count -= 1  # remove header
+
+        log.info("%d tags to import", line_count)
+        tags_csv_fd.seek(0)
+        tags_reader = csv.reader(tags_csv_fd)
+
+        assert len(next(tags_reader)) == 4
+
+        processed_count = 0
+        processed_ratio = 0
+
+        for row in tags_reader:
+            tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
+            await ctx.db.execute(
+                "insert into tags (id, name, category, post_count) values (?, ?, ?, ?)",
+                (tag.id, tag.name, tag.category, tag.post_count),
+            )
+            processed_count += 1
+            new_processed_ratio = round((processed_count / line_count) * 100)
+            if new_processed_ratio != processed_ratio:
+                log.info("processed at %d%%", new_processed_ratio)
+                processed_ratio = new_processed_ratio
+
+        log.info("done")
+
+        await ctx.db.commit()
 
 
 async def main():
     wanted_date = sys.argv[1]
-    urls = (
-        f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
-        f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
-    )
 
     async with aiohttp.ClientSession() as session:
-        ctx = Context(session)
-
-        for url in urls:
-            parsed = urlparse(url)
-            parsed_path = Path(parsed.path)
-            output_path = Path.cwd() / parsed_path.name
-            if output_path.exists():
-                log.info("file %s already exists, ignoring", output_path)
-                continue
-
-            log.info("downloading %r into %s", url, output_path)
-            async with ctx.session.get(url) as resp:
-                assert resp.status == 200
-
-                total_length = int(resp.headers["content-length"])
-                downloaded_bytes = 0
-                download_ratio = 0
-
-                log.info("to download %d bytes", total_length)
-
-                with tempfile.TemporaryFile() as temp_fd:
-                    async for chunk in resp.content.iter_chunked(8192):
-                        temp_fd.write(chunk)
-                        downloaded_bytes += len(chunk)
-                        new_download_ratio = round(
-                            (downloaded_bytes / total_length) * 100
-                        )
-                        if new_download_ratio != download_ratio:
-                            log.info("download at %d%%", download_ratio)
-                            download_ratio = new_download_ratio
-
-                    temp_fd.seek(0)
-
-                    # write to output
-                    log.info("copying temp to output")
-                    with output_path.open(mode="wb") as output_fd:
-                        shutil.copyfileobj(temp_fd, output_fd)
-
-        # decompress
-        for url in urls:
-            parsed = urlparse(url)
-            parsed_path = Path(parsed.path)
-            input_path = Path.cwd() / parsed_path.name
-            original_name, original_extension, _gz = parsed_path.name.split(".")
-            output_path = Path.cwd() / f"{original_name}.{original_extension}"
-            if output_path.exists():
-                log.info("decompressed file %s already exists, ignoring", output_path)
-                continue
-
-            log.info("decompressing %s into %s", input_path.name, output_path.name)
-            with gzip.open(input_path, "rb") as in_fd:
-                with output_path.open(mode="wb") as out_fd:
-                    shutil.copyfileobj(in_fd, out_fd)
-
-        # now that everythings downloaded, compile the db
+        async with aiosqlite.connect("./e621.db") as db:
+            ctx = Context(session, db)
+            await main_with_ctx(ctx, wanted_date)
 
 
 if __name__ == "__main__":
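A quick sanity check of the result (a sketch, not part of the patch): after running the script as `python build_database.py <date>` with a date for which e621 publishes an export, the database it writes to `./e621.db` can be inspected with the stdlib sqlite3 module:

    import sqlite3

    # open the database the script writes to the working directory
    db = sqlite3.connect("./e621.db")

    # print the five most-used tags imported from the CSV export
    for name, post_count in db.execute(
        "select name, post_count from tags order by post_count desc limit 5"
    ):
        print(name, post_count)

    db.close()

One caveat of the import as written: it uses a plain `insert`, so re-running the script against an already-populated `e621.db` will fail on the `tags` primary-key and unique constraints; delete the database file first to rebuild it.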