insert tags into sqlite db

Luna 2022-08-28 01:07:51 -03:00
parent 58e9e4e9cb
commit c987eb7035
1 changed file with 149 additions and 59 deletions


@@ -1,4 +1,5 @@
 import asyncio
+import csv
 import gzip
 import logging
 import shutil
@@ -7,7 +8,9 @@ import sys
 from dataclasses import dataclass
 from urllib.parse import urlparse
 from pathlib import Path
+from typing import Any
 
+import aiosqlite
 import aiohttp
 
 log = logging.getLogger(__name__)
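Together with csv above, these are the commit's new dependencies: aiosqlite exposes the standard sqlite3 API behind async/await. A minimal, self-contained sketch of the calls the new code leans on (an illustration, not part of this commit):

import asyncio

import aiosqlite


async def demo():
    # every sqlite3 operation becomes awaitable; connect() doubles as an
    # async context manager that closes the connection on exit
    async with aiosqlite.connect(":memory:") as db:
        await db.execute("CREATE TABLE t (x int)")
        await db.execute("INSERT INTO t VALUES (?)", (1,))
        await db.commit()
        async with db.execute("SELECT x FROM t") as cursor:
            print(await cursor.fetchall())  # [(1,)]


asyncio.run(demo())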
@@ -16,22 +19,38 @@ log = logging.getLogger(__name__)
 @dataclass
 class Context:
     session: aiohttp.ClientSession
+    db: Any
 
 
-async def main():
-    wanted_date = sys.argv[1]
-    urls = (
-        f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
-        f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
-    )
-
-    async with aiohttp.ClientSession() as session:
-        ctx = Context(session)
-
-        for url in urls:
-            parsed = urlparse(url)
-            parsed_path = Path(parsed.path)
-            output_path = Path.cwd() / parsed_path.name
-            if output_path.exists():
-                log.info("file %s already exists, ignoring", output_path)
-                continue
+@dataclass
+class Tag:
+    id: int
+    name: str
+    category: int
+    post_count: int
+
+
+async def main_with_ctx(ctx, wanted_date):
+    urls = {
+        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
+        # "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
+    }
+
+    output_compressed_paths = {}
+    output_uncompressed_paths = {}
+
+    for url_type, url in urls.items():
+        parsed = urlparse(url)
+        parsed_path = Path(parsed.path)
+        output_path = Path.cwd() / parsed_path.name
+        output_compressed_paths[url_type] = output_path
+
+        original_name, original_extension, _gz = parsed_path.name.split(".")
+        output_uncompressed = Path.cwd() / f"{original_name}.{original_extension}"
+        output_uncompressed_paths[url_type] = output_uncompressed
+
+    for url_type, url in urls.items():
+        output_path = output_compressed_paths[url_type]
+        if output_path.exists():
+            log.info("file %s already exists, ignoring", output_path)
+            continue
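Tag mirrors the four columns of the e621 tags export, in the order the import loop below unpacks them: id, name, category, post_count. A hypothetical row, parsed the way that loop does it (the values here are made up):

# row as csv.reader would yield it; every field arrives as a string
row = ["1337", "fox", "5", "42000"]
tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
assert tag.name == "fox" and tag.post_count == 42000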
@@ -50,9 +69,7 @@ async def main():
             async for chunk in resp.content.iter_chunked(8192):
                 temp_fd.write(chunk)
                 downloaded_bytes += len(chunk)
-                new_download_ratio = round(
-                    (downloaded_bytes / total_length) * 100
-                )
+                new_download_ratio = round((downloaded_bytes / total_length) * 100)
                 if new_download_ratio != download_ratio:
                     log.info("download at %d%%", download_ratio)
                     download_ratio = new_download_ratio
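One quirk in this hunk: the log line reports download_ratio before it is refreshed, so the printed percentage lags one update behind (the first message always says 0%). The tag-import loop added further down repeats the same pattern with processed_ratio. The presumably intended form:

new_download_ratio = round((downloaded_bytes / total_length) * 100)
if new_download_ratio != download_ratio:
    log.info("download at %d%%", new_download_ratio)  # log the fresh value
    download_ratio = new_download_ratio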
@@ -65,12 +82,10 @@ async def main():
             shutil.copyfileobj(temp_fd, output_fd)
 
     # decompress
-    for url in urls:
-        parsed = urlparse(url)
-        parsed_path = Path(parsed.path)
-        input_path = Path.cwd() / parsed_path.name
-        original_name, original_extension, _gz = parsed_path.name.split(".")
-        output_path = Path.cwd() / f"{original_name}.{original_extension}"
+    for url_type, _url in urls.items():
+        input_path = output_compressed_paths[url_type]
+        output_path = output_uncompressed_paths[url_type]
+
         if output_path.exists():
             log.info("decompressed file %s already exists, ignoring", output_path)
             continue
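Both passes stay streaming: the download writes 8 KiB chunks, and decompression goes through shutil.copyfileobj, so a multi-gigabyte dump never has to fit in memory. The decompression step in isolation (file names are hypothetical, following the tags-{date}.csv.gz pattern above):

import gzip
import shutil

# gzip.open returns a binary file object that inflates transparently as it is read
with gzip.open("tags-2022-08-28.csv.gz", "rb") as in_fd:
    with open("tags-2022-08-28.csv", "wb") as out_fd:
        shutil.copyfileobj(in_fd, out_fd)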
@@ -81,6 +96,81 @@ async def main():
             shutil.copyfileobj(in_fd, out_fd)
 
     # now that everythings downloaded, compile the db
+    await ctx.db.executescript(
+        """
+    CREATE TABLE IF NOT EXISTS db_version (
+        version int primary key
+    ) STRICT;
+
+    CREATE TABLE IF NOT EXISTS tags (
+        id int primary key,
+        name text unique,
+        category int not null,
+        post_count int not null
+    ) STRICT;
+
+    CREATE TABLE IF NOT EXISTS posts (
+        id int primary key,
+        uploader_id int not null,
+        created_at int not null,
+        md5 text unique not null,
+        source text not null,
+        rating text not null,
+        tag_string text not null,
+        is_deleted int not null,
+        is_pending int not null,
+        is_flagged int not null,
+        score int not null,
+        up_score int not null,
+        down_score int not null,
+        is_rating_locked int not null
+    ) STRICT;
+    """
+    )
+    await ctx.db.commit()
+
+    with output_uncompressed_paths["tags"].open(
+        mode="r", encoding="utf-8"
+    ) as tags_csv_fd:
+        line_count = 0
+        for line in tags_csv_fd:
+            line_count += 1
+        line_count -= 1  # remove header
+        log.info("%d tags to import", line_count)
+        tags_csv_fd.seek(0)
+
+        tags_reader = csv.reader(tags_csv_fd)
+        assert len(next(tags_reader)) == 4
+
+        processed_count = 0
+        processed_ratio = 0
+        for row in tags_reader:
+            tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
+            await ctx.db.execute(
+                "insert into tags (id, name, category, post_count) values (?, ?, ?, ?)",
+                (tag.id, tag.name, tag.category, tag.post_count),
+            )
+            processed_count += 1
+            new_processed_ratio = round((processed_count / line_count) * 100)
+            if new_processed_ratio != processed_ratio:
+                log.info("processed at %d%%", processed_ratio)
+                processed_ratio = new_processed_ratio
+
+        log.info("done")
+        await ctx.db.commit()
+
+
+async def main():
+    wanted_date = sys.argv[1]
+    async with aiohttp.ClientSession() as session:
+        async with aiosqlite.connect("./e621.db") as db:
+            ctx = Context(session, db)
+            await main_with_ctx(ctx, wanted_date)
 
 
 if __name__ == "__main__":
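Two follow-up notes on the import loop. Each row costs one awaited execute(), and every await round-trips through aiosqlite's worker thread, so a large tags dump pays that overhead per tag; if the import becomes slow, batching through executemany is one option (a sketch against the same schema, not something this commit does):

import itertools

BATCH_SIZE = 10_000  # assumed batch size; tune against memory and latency


async def insert_tags_batched(db, tags_reader):
    # drain the csv reader in fixed-size slices instead of row by row
    while True:
        batch = [
            (int(row[0]), row[1], int(row[2]), int(row[3]))
            for row in itertools.islice(tags_reader, BATCH_SIZE)
        ]
        if not batch:
            break
        await db.executemany(
            "insert into tags (id, name, category, post_count) values (?, ?, ?, ?)",
            batch,
        )
    await db.commit()

Separately, a second run against an existing e621.db will fail on the tags primary key, since the statement is a plain insert; "insert or replace" (or wiping the database file first) would make the import rerunnable.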