e621_api_cloner/build_database.py
2022-10-04 14:03:52 -03:00

313 lines
9.8 KiB
Python

import asyncio
import subprocess
import csv
import gzip
import logging
import shutil
import tempfile
import sys
from dataclasses import dataclass
from urllib.parse import urlparse
from datetime import datetime
from pathlib import Path
from typing import Any
import aiosqlite
import aiohttp
# Module-level logger; all status and progress messages in this script go through it.
log = logging.getLogger(__name__)
@dataclass
class Context:
    """Resources shared by the build steps in main_with_ctx()."""

    # HTTP session used to download the export archives.
    session: aiohttp.ClientSession
    # Connection to the SQLite database being built (opened with aiosqlite.connect in main()).
    db: aiosqlite.Connection
@dataclass
class Tag:
    """One row of the tags CSV export; fields mirror the ``tags`` table columns.

    Built positionally from the four CSV columns, in the order declared here.
    """

    id: int
    name: str
    # Numeric category code, stored as-is (meaning of each value not defined here).
    category: int
    post_count: int
@dataclass
class Post:
    """One row of the posts CSV export; fields mirror the ``posts`` table columns."""

    id: int
    uploader_id: int
    # Unix timestamp in whole seconds, converted from the export's created_at text.
    created_at: int
    md5: str
    source: str
    rating: str
    # All tag names in a single string, exactly as exported.
    tag_string: str
    # The is_* flags are decoded with e621_bool(), so they hold bools
    # (bool is an int subclass, matching the INTEGER columns).
    is_deleted: bool
    is_pending: bool
    is_flagged: bool
    score: int
    up_score: int
    down_score: int
    is_rating_locked: bool
def e621_bool(text: str) -> bool:
    """Decode a boolean cell from an e621 CSV export.

    The export marks true values with the literal ``"t"``. Every other value,
    including ``"f"``, the empty string, and differently-cased variants,
    decodes as ``False``.
    """
    # Exact, case-sensitive comparison against the single "true" marker.
    return "t" == text
async def _download_file(ctx: "Context", url: str, output_path: Path) -> None:
    """Stream *url* into *output_path*, logging download progress.

    Bytes are written to a sibling ``<name>.part`` file, which is renamed onto
    *output_path* only after the whole body has been received. An interrupted
    run therefore never leaves a truncated archive that a later run would
    skip as "already exists".

    Raises:
        aiohttp.ClientResponseError: if the server answers with an error status.
    """
    partial_path = output_path.with_name(output_path.name + ".part")
    log.info("downloading %r into %s", url, output_path)
    async with ctx.session.get(url) as resp:
        # Unlike `assert`, this check is not stripped under `python -O`.
        resp.raise_for_status()
        # None when the server omits Content-Length; percentages are skipped then.
        total_length = resp.content_length
        if total_length is None:
            log.info("to download an unknown number of bytes")
        else:
            log.info("to download %d bytes", total_length)
        downloaded_bytes = 0
        download_ratio = 0
        with partial_path.open(mode="wb") as part_fd:
            async for chunk in resp.content.iter_chunked(8192):
                part_fd.write(chunk)
                downloaded_bytes += len(chunk)
                if not total_length:
                    continue
                new_download_ratio = round((downloaded_bytes / total_length) * 100)
                if new_download_ratio != download_ratio:
                    # Update first so the message reports the new percentage.
                    download_ratio = new_download_ratio
                    log.info("download at %d%%", download_ratio)
    # os.replace semantics: output_path ends up either absent or complete.
    partial_path.replace(output_path)


def _decompress_file(input_path: Path, output_path: Path) -> None:
    """Gunzip *input_path* into *output_path* via a ``.part`` file and rename."""
    partial_path = output_path.with_name(output_path.name + ".part")
    log.info("decompressing %s into %s", input_path.name, output_path.name)
    with gzip.open(input_path, "rb") as in_fd, partial_path.open(mode="wb") as out_fd:
        shutil.copyfileobj(in_fd, out_fd)
    partial_path.replace(output_path)


def _count_csv_rows(csv_fd) -> int:
    """Return the number of CSV records in *csv_fd*, excluding the header.

    Counts parsed records rather than physical lines, so quoted fields that
    contain newlines are not over-counted. The file is rewound before and
    after counting so the caller can parse it again from the start.
    """
    csv_fd.seek(0)
    record_count = sum(1 for _row in csv.reader(csv_fd))
    csv_fd.seek(0)
    return max(record_count - 1, 0)  # remove header


async def _create_schema(db) -> None:
    """Create the db_version, tags and posts tables if missing, then commit."""
    await db.executescript(
        """
        CREATE TABLE IF NOT EXISTS db_version (
            version int primary key
        ) STRICT;
        CREATE TABLE IF NOT EXISTS tags (
            id int primary key,
            name text unique,
            category int not null,
            post_count int not null
        ) STRICT;
        CREATE TABLE IF NOT EXISTS posts (
            id int primary key,
            uploader_id int not null,
            created_at int not null,
            md5 text unique not null,
            source text not null,
            rating text not null,
            tag_string text not null,
            is_deleted int not null,
            is_pending int not null,
            is_flagged int not null,
            score int not null,
            up_score int not null,
            down_score int not null,
            is_rating_locked int not null
        ) STRICT;
        """
    )
    await db.commit()


async def _import_tags(ctx: "Context", tags_path: Path) -> bool:
    """Insert the rows of the tags CSV into ``tags``; return True if it ran.

    The import is skipped (returning False) when the table already holds as
    many rows as the CSV has records. Rows that conflict with existing ones
    are ignored (``on conflict do nothing``). The caller commits.

    Raises:
        ValueError: if the CSV header is missing or does not have 4 columns.
    """
    tag_count_rows = await ctx.db.execute_fetchall("select count(*) from tags")
    tag_count = tag_count_rows[0][0]
    log.info("already have %d tags", tag_count)
    # newline="" is what the csv module requires for correct quoted-field parsing.
    with tags_path.open(mode="r", encoding="utf-8", newline="") as tags_csv_fd:
        line_count = _count_csv_rows(tags_csv_fd)
        log.info("%d tags to import", line_count)
        if line_count == tag_count:
            log.info("same counts, not going to reimport")
            return False
        tags_reader = csv.reader(tags_csv_fd)
        header = next(tags_reader, None)
        if header is None or len(header) != 4:
            raise ValueError(f"unexpected tags CSV header: {header!r}")
        processed_count = 0
        processed_ratio = 0
        for row in tags_reader:
            tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
            await ctx.db.execute(
                "insert into tags (id, name, category, post_count) values (?, ?, ?, ?) on conflict do nothing",
                (tag.id, tag.name, tag.category, tag.post_count),
            )
            processed_count += 1
            new_processed_ratio = round((processed_count / line_count) * 100)
            if new_processed_ratio != processed_ratio:
                processed_ratio = new_processed_ratio
                log.info("tags processed at %d%%", processed_ratio)
    log.info("tags done")
    return True


def _parse_post(row) -> "Post":
    """Convert one csv.DictReader row of the posts export into a Post."""
    created_at_str = row["created_at"]
    # Drop a fractional-seconds suffix if present. partition() leaves strings
    # without "." intact; slicing at str.find(".") would cut the last character.
    created_at = datetime.strptime(
        created_at_str.partition(".")[0], "%Y-%m-%d %H:%M:%S"
    )
    return Post(
        id=int(row["id"]),
        uploader_id=int(row["uploader_id"]),
        # NOTE(review): created_at is naive, so .timestamp() uses the machine's
        # local timezone. If the export is UTC this should attach
        # timezone.utc; confirm before changing (existing rows use local time).
        created_at=int(created_at.timestamp()),
        md5=row["md5"],
        source=row["source"],
        rating=row["rating"],
        tag_string=row["tag_string"],
        is_deleted=e621_bool(row["is_deleted"]),
        is_pending=e621_bool(row["is_pending"]),
        is_flagged=e621_bool(row["is_flagged"]),
        score=int(row["score"]),
        up_score=int(row["up_score"]),
        down_score=int(row["down_score"]),
        is_rating_locked=e621_bool(row["is_rating_locked"]),
    )


async def _import_posts(ctx: "Context", posts_path: Path) -> bool:
    """Insert the rows of the posts CSV into ``posts``; return True if it ran.

    The import is skipped (returning False) when the table already holds as
    many rows as the CSV has records. Conflicting rows are ignored. The
    caller commits.
    """
    post_count_rows = await ctx.db.execute_fetchall("select count(*) from posts")
    post_count = post_count_rows[0][0]
    log.info("already have %d posts", post_count)
    with posts_path.open(mode="r", encoding="utf-8", newline="") as posts_csv_fd:
        line_count = _count_csv_rows(posts_csv_fd)
        log.info("%d posts to import", line_count)
        if line_count == post_count:
            log.info("already imported everything, skipping")
            return False
        processed_count = 0
        processed_ratio = 0.0
        for row in csv.DictReader(posts_csv_fd):
            post = _parse_post(row)
            await ctx.db.execute(
                """
                insert into posts (
                    id,
                    uploader_id,
                    created_at,
                    md5,
                    source,
                    rating,
                    tag_string,
                    is_deleted,
                    is_pending,
                    is_flagged,
                    score,
                    up_score,
                    down_score,
                    is_rating_locked
                ) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?) on conflict do nothing
                """,
                (
                    post.id,
                    post.uploader_id,
                    post.created_at,
                    post.md5,
                    post.source,
                    post.rating,
                    post.tag_string,
                    post.is_deleted,
                    post.is_pending,
                    post.is_flagged,
                    post.score,
                    post.up_score,
                    post.down_score,
                    post.is_rating_locked,
                ),
            )
            processed_count += 1
            # Two decimals: the posts export is large, so whole percents are too coarse.
            new_processed_ratio = round((processed_count / line_count) * 100, 2)
            if new_processed_ratio != processed_ratio:
                processed_ratio = new_processed_ratio
                log.info("posts processed at %.2f%%", processed_ratio)
    log.info("posts done")
    return True


async def main_with_ctx(ctx, wanted_date):
    """Download, decompress and import the e621 export for *wanted_date*.

    The steps are:

    1. Fetch ``tags-<date>.csv.gz`` and ``posts-<date>.csv.gz`` from
       ``https://e621.net/db_export/`` into the current working directory,
       skipping files that already exist.
    2. Gunzip each archive next to it.
    3. Create the tables in ``ctx.db`` if they are missing.
    4. Import each CSV unless its table already holds as many rows as the
       CSV has records.
    5. VACUUM the database if anything was imported.

    Args:
        ctx: Context carrying the aiohttp session and the aiosqlite connection.
        wanted_date: date part of the export file names, substituted verbatim.
    """
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }
    compressed_paths = {}
    uncompressed_paths = {}
    for url_type, url in urls.items():
        file_name = Path(urlparse(url).path).name
        compressed_paths[url_type] = Path.cwd() / file_name
        # "tags-<date>.csv.gz" -> "tags-<date>.csv" (stem drops only ".gz").
        uncompressed_paths[url_type] = Path.cwd() / Path(file_name).stem

    for url_type, url in urls.items():
        output_path = compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue
        await _download_file(ctx, url, output_path)

    for url_type in urls:
        input_path = compressed_paths[url_type]
        output_path = uncompressed_paths[url_type]
        if output_path.exists():
            log.info("decompressed file %s already exists, ignoring", output_path)
            continue
        _decompress_file(input_path, output_path)

    # now that everything's downloaded, compile the db
    await _create_schema(ctx.db)

    tags_imported = await _import_tags(ctx, uncompressed_paths["tags"])
    await ctx.db.commit()
    log.info("going to process posts")
    posts_imported = await _import_posts(ctx, uncompressed_paths["posts"])
    await ctx.db.commit()

    if tags_imported or posts_imported:
        log.info("vacuuming db...")
        await ctx.db.execute("vacuum")
    log.info("database built")
async def main():
    """Build ``./e621.db`` from the e621 export dated by the first CLI argument.

    Exits with a usage message (instead of an IndexError traceback) when the
    date argument is missing.
    """
    if len(sys.argv) < 2:
        raise SystemExit(
            f"usage: {sys.argv[0]} DATE  "
            "(as in https://e621.net/db_export/tags-DATE.csv.gz)"
        )
    wanted_date = sys.argv[1]
    async with aiohttp.ClientSession() as session:
        async with aiosqlite.connect("./e621.db") as db:
            ctx = Context(session, db)
            await main_with_ctx(ctx, wanted_date)
if __name__ == "__main__":
    # Raise the csv module's per-field size cap to 2**19 (524288); presumably
    # some exported fields exceed the default limit — TODO confirm which.
    csv.field_size_limit(1 << 19)
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())