e621_api_cloner/build_database.py

299 lines
9 KiB
Python
Raw Normal View History

2022-08-28 02:56:24 +00:00
import asyncio
2022-08-28 19:18:54 +00:00
import subprocess
2022-08-28 04:07:51 +00:00
import csv
2022-08-28 03:24:38 +00:00
import gzip
2022-08-28 02:56:24 +00:00
import logging
import shutil
import tempfile
import sys
from dataclasses import dataclass
from urllib.parse import urlparse
2022-08-28 19:13:26 +00:00
from datetime import datetime
2022-08-28 02:56:24 +00:00
from pathlib import Path
2022-08-28 04:07:51 +00:00
from typing import Any
2022-08-28 02:56:24 +00:00
2022-08-28 04:07:51 +00:00
import aiosqlite
2022-08-28 02:56:24 +00:00
import aiohttp
# Module-level logger; when run as a script, output is configured by the
# logging.basicConfig call in the ``__main__`` guard.
log = logging.getLogger(name=__name__)
@dataclass
class Context:
    """Shared handles threaded through ``main_with_ctx``."""

    # HTTP session used to download the export archives.
    session: aiohttp.ClientSession
    # Open connection to the SQLite database being built.
    db: aiosqlite.Connection
@dataclass
class Tag:
    """One record of the tags CSV export, typed for the ``tags`` table.

    Field order matches the CSV column order, so instances are normally
    built positionally from a parsed row.
    """

    id: int  # primary key of the ``tags`` table
    name: str  # tag text; declared unique in the ``tags`` table
    category: int  # numeric category column, stored as-is
    post_count: int  # post_count column, stored as-is
2022-08-28 19:13:26 +00:00
@dataclass
class Post:
    """One record of the posts CSV export, typed for the ``posts`` table.

    The ``is_*`` flags hold real ``bool`` values (the output of
    ``e621_bool``); Python's sqlite3 binds them as 0/1, which fits the
    table's ``int`` columns.
    """

    id: int
    uploader_id: int
    created_at: int  # seconds since the epoch, from datetime.timestamp()
    md5: str
    source: str
    rating: str
    tag_string: str
    is_deleted: bool
    is_pending: bool
    is_flagged: bool
    score: int
    up_score: int
    down_score: int
    is_rating_locked: bool
def e621_bool(text: str) -> bool:
    """Decode a boolean cell from the e621 CSV export.

    Only the exact string ``"t"`` counts as true; every other value,
    including ``"f"``, ``"T"`` and the empty string, decodes to ``False``.
    """
    true_marker = "t"
    return true_marker == text
2022-08-28 04:07:51 +00:00
# Parameter tuples handed to each executemany() call during the CSV imports.
# Batching avoids one awaited aiosqlite round-trip per CSV row.
_INSERT_BATCH_SIZE = 10_000

_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS db_version (
    version int primary key
) STRICT;
CREATE TABLE IF NOT EXISTS tags (
    id int primary key,
    name text unique,
    category int not null,
    post_count int not null
) STRICT;
CREATE TABLE IF NOT EXISTS posts (
    id int primary key,
    uploader_id int not null,
    created_at int not null,
    md5 text unique not null,
    source text not null,
    rating text not null,
    tag_string text not null,
    is_deleted int not null,
    is_pending int not null,
    is_flagged int not null,
    score int not null,
    up_score int not null,
    down_score int not null,
    is_rating_locked int not null
) STRICT;
"""

# "insert or replace" lets a re-run (or a newer export) overwrite rows that
# already exist instead of aborting on the primary key. Note that SQLite's
# REPLACE also evicts rows clashing on any other unique column (tags.name,
# posts.md5).
_TAGS_INSERT_SQL = (
    "insert or replace into tags (id, name, category, post_count) "
    "values (?, ?, ?, ?)"
)

_POSTS_INSERT_SQL = """
insert or replace into posts (
    id,
    uploader_id,
    created_at,
    md5,
    source,
    rating,
    tag_string,
    is_deleted,
    is_pending,
    is_flagged,
    score,
    up_score,
    down_score,
    is_rating_locked
) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
"""


def _batched(iterable, size):
    """Yield consecutive lists of at most *size* items from *iterable*."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) >= size:
            yield batch
            batch = []
    if batch:
        yield batch


def _count_csv_rows(csv_fd):
    """Return the number of CSV records after the header, then rewind *csv_fd*.

    Counts parsed records rather than raw lines, so a quoted field that
    contains a newline is still counted once. Never returns a negative value.
    """
    record_count = sum(1 for _record in csv.reader(csv_fd)) - 1
    csv_fd.seek(0)
    return max(record_count, 0)


async def _table_row_count(db, table: str) -> int:
    """Return ``count(*)`` of *table*; *table* must be a hard-coded name."""
    rows = await db.execute_fetchall(f"select count(*) from {table}")
    return rows[0][0]


async def _insert_rows(db, sql: str, rows, total: int, label: str, decimals: int) -> None:
    """Insert the parameter tuples in *rows* in batches, logging progress.

    :param total: expected row count, used only to compute the percentage.
    :param label: noun used in the progress message ("tags", "posts").
    :param decimals: number of decimals in the logged percentage.
    """
    message = f"{label} processed at %.{decimals}f%%"
    processed = 0
    last_ratio = None
    for batch in _batched(rows, _INSERT_BATCH_SIZE):
        await db.executemany(sql, batch)
        processed += len(batch)
        if total <= 0:
            continue
        ratio = round(processed / total * 100, decimals)
        # Log the ratio just reached, and only when its displayed value moves.
        if ratio != last_ratio:
            last_ratio = ratio
            log.info(message, ratio)


async def _download(session, url: str, output_path: Path) -> None:
    """Stream *url* into *output_path*, logging whole-percent progress.

    :raises RuntimeError: if the server answers with anything but HTTP 200.
    """
    log.info("downloading %r into %s", url, output_path)
    # Write to a ".part" sibling and rename only once the whole body has
    # arrived, so an interrupted run never leaves a truncated archive that
    # the next run would skip as "already exists".
    partial_path = output_path.with_name(output_path.name + ".part")
    async with session.get(url) as resp:
        if resp.status != 200:
            raise RuntimeError(f"GET {url!r} returned HTTP {resp.status}")
        content_length = resp.headers.get("content-length")
        total_length = int(content_length) if content_length else 0
        log.info("to download %d bytes", total_length)
        downloaded_bytes = 0
        download_ratio = 0
        with partial_path.open(mode="wb") as part_fd:
            async for chunk in resp.content.iter_chunked(8192):
                part_fd.write(chunk)
                downloaded_bytes += len(chunk)
                if not total_length:
                    continue  # size unknown: no percentage to report
                new_download_ratio = round(downloaded_bytes / total_length * 100)
                if new_download_ratio != download_ratio:
                    download_ratio = new_download_ratio
                    log.info("download at %d%%", download_ratio)
    partial_path.replace(output_path)


def _decompress(input_path: Path, output_path: Path) -> None:
    """Gunzip *input_path* into *output_path* via a ``.part`` file + rename."""
    log.info("decompressing %s into %s", input_path.name, output_path.name)
    # Same reasoning as the download: a half-written CSV must never be
    # mistaken for a finished one on the next run.
    partial_path = output_path.with_name(output_path.name + ".part")
    with gzip.open(input_path, "rb") as in_fd, partial_path.open(mode="wb") as out_fd:
        shutil.copyfileobj(in_fd, out_fd)
    partial_path.replace(output_path)


async def _create_schema(db) -> None:
    """Create the output tables if they do not exist yet, then commit."""
    await db.executescript(_SCHEMA_SQL)
    await db.commit()


def _tag_params(row):
    """Convert one tags-CSV record into insert parameters (via ``Tag``)."""
    tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
    return (tag.id, tag.name, tag.category, tag.post_count)


def _post_params(row):
    """Convert one posts-CSV ``DictReader`` record into insert parameters."""
    # Drop fractional seconds (everything from the first ".") before parsing;
    # partition() leaves the string whole when there is no ".".
    created_at_text = row["created_at"].partition(".")[0]
    created_at = datetime.strptime(created_at_text, "%Y-%m-%d %H:%M:%S")
    # NOTE(review): created_at is naive, so timestamp() interprets it in the
    # machine's local timezone. If the export is UTC, attach timezone.utc
    # first — confirm against the dump before relying on these values.
    post = Post(
        id=int(row["id"]),
        uploader_id=int(row["uploader_id"]),
        created_at=int(created_at.timestamp()),
        md5=row["md5"],
        source=row["source"],
        rating=row["rating"],
        tag_string=row["tag_string"],
        is_deleted=e621_bool(row["is_deleted"]),
        is_pending=e621_bool(row["is_pending"]),
        is_flagged=e621_bool(row["is_flagged"]),
        score=int(row["score"]),
        up_score=int(row["up_score"]),
        down_score=int(row["down_score"]),
        is_rating_locked=e621_bool(row["is_rating_locked"]),
    )
    return (
        post.id,
        post.uploader_id,
        post.created_at,
        post.md5,
        post.source,
        post.rating,
        post.tag_string,
        post.is_deleted,
        post.is_pending,
        post.is_flagged,
        post.score,
        post.up_score,
        post.down_score,
        post.is_rating_locked,
    )


async def _import_tags(db, tags_path: Path) -> None:
    """Load the tags CSV unless the table already holds as many rows.

    :raises ValueError: if the CSV header does not have exactly 4 columns.
    """
    tag_count = await _table_row_count(db, "tags")
    log.info("already have %d tags", tag_count)
    # newline="" is what the csv module requires for correct quoting.
    with tags_path.open(mode="r", encoding="utf-8", newline="") as tags_csv_fd:
        line_count = _count_csv_rows(tags_csv_fd)
        log.info("%d tags to import", line_count)
        if line_count == tag_count:
            log.info("same counts, not going to reimport")
        else:
            tags_reader = csv.reader(tags_csv_fd)
            header = next(tags_reader, [])
            if len(header) != 4:
                raise ValueError(f"expected 4 columns in tags CSV header, got {header!r}")
            await _insert_rows(
                db, _TAGS_INSERT_SQL, map(_tag_params, tags_reader), line_count, "tags", 0
            )
    log.info("tags done")
    await db.commit()


async def _import_posts(db, posts_path: Path) -> None:
    """Load the posts CSV unless the table already holds as many rows."""
    post_count = await _table_row_count(db, "posts")
    log.info("already have %d posts", post_count)
    with posts_path.open(mode="r", encoding="utf-8", newline="") as posts_csv_fd:
        line_count = _count_csv_rows(posts_csv_fd)
        log.info("%d posts to import", line_count)
        if line_count == post_count:
            log.info("same counts, not going to reimport")
        else:
            posts_reader = csv.DictReader(posts_csv_fd)
            await _insert_rows(
                db, _POSTS_INSERT_SQL, map(_post_params, posts_reader), line_count, "posts", 2
            )
    log.info("posts done")
    await db.commit()


async def main_with_ctx(ctx, wanted_date):
    """Mirror the e621 tag and post exports for *wanted_date* into ``ctx.db``.

    Stages, each skipped when its result already exists so an interrupted
    run can simply be restarted:

    1. download ``tags-<date>.csv.gz`` and ``posts-<date>.csv.gz`` into the cwd;
    2. gunzip each archive next to itself (``.csv``);
    3. create the schema and import both CSVs (per table, skipped when the
       row count already equals the CSV record count).

    :param ctx: ``Context`` holding the aiohttp session and aiosqlite connection.
    :param wanted_date: string interpolated into the export file names.
    :raises RuntimeError: if a download does not answer HTTP 200.
    :raises ValueError: if the tags CSV header is malformed.
    """
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }
    compressed_paths = {}
    uncompressed_paths = {}
    for url_type, url in urls.items():
        compressed = Path.cwd() / Path(urlparse(url).path).name
        compressed_paths[url_type] = compressed
        # "tags-<date>.csv.gz" -> "tags-<date>.csv"
        uncompressed_paths[url_type] = compressed.with_suffix("")

    for url_type, url in urls.items():
        output_path = compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue
        await _download(ctx.session, url, output_path)

    for url_type in urls:
        output_path = uncompressed_paths[url_type]
        if output_path.exists():
            log.info("decompressed file %s already exists, ignoring", output_path)
            continue
        _decompress(compressed_paths[url_type], output_path)

    # now that everything is downloaded, compile the db
    await _create_schema(ctx.db)
    await _import_tags(ctx.db, uncompressed_paths["tags"])
    log.info("going to process posts")
    await _import_posts(ctx.db, uncompressed_paths["posts"])
2022-08-28 02:56:24 +00:00
async def main():
    """Script entry point: build ``./e621.db`` for the date given in argv.

    Usage: ``build_database.py <export date>``. The date is interpolated
    verbatim into the ``db_export`` file names.

    :raises SystemExit: with a usage message when the date argument is missing.
    """
    if len(sys.argv) < 2:
        raise SystemExit(f"usage: {sys.argv[0]} <export date>")
    wanted_date = sys.argv[1]
    async with aiohttp.ClientSession() as session:
        async with aiosqlite.connect("./e621.db") as db:
            ctx = Context(session, db)
            await main_with_ctx(ctx, wanted_date)
2022-08-28 02:56:24 +00:00
if __name__ == "__main__":
    # Raise the csv module's per-field size cap to 2**19 bytes (512 KiB),
    # above its default of 131072, before any CSV is parsed; presumably some
    # export fields exceed the default.
    csv.field_size_limit(1 << 19)
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())