insert tags into sqlite db
parent 58e9e4e9cb
commit c987eb7035
1 changed file with 149 additions and 59 deletions
@@ -1,4 +1,5 @@
import asyncio
import csv
import gzip
import logging
import shutil
@@ -7,7 +8,9 @@ import sys
from dataclasses import dataclass
from urllib.parse import urlparse
from pathlib import Path
from typing import Any

import aiosqlite
import aiohttp

log = logging.getLogger(__name__)
@@ -16,71 +19,158 @@ log = logging.getLogger(__name__)
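# Plain containers for shared state: Context carries the aiohttp session and
# the aiosqlite connection, Tag mirrors one row of the tags CSV export.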
@dataclass
class Context:
    session: aiohttp.ClientSession
    db: Any = None  # default so a Context can be built before the sqlite connection is opened


@dataclass
class Tag:
    id: int
    name: str
    category: int
    post_count: int


async def main_with_ctx(ctx, wanted_date):
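    # Fetch the dated CSV export(s), decompress them, create the sqlite
    # schema, then insert every row of the tags CSV into the tags table.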
    urls = {
        "tags": f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        # "posts": f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    }

    output_compressed_paths = {}
    output_uncompressed_paths = {}

    for url_type, url in urls.items():
        parsed = urlparse(url)
        parsed_path = Path(parsed.path)
        output_path = Path.cwd() / parsed_path.name
        output_compressed_paths[url_type] = output_path

        original_name, original_extension, _gz = parsed_path.name.split(".")
        output_uncompressed = Path.cwd() / f"{original_name}.{original_extension}"
        output_uncompressed_paths[url_type] = output_uncompressed
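
    # Download each compressed export into the working directory, skipping
    # files left over from a previous run.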
    for url_type, url in urls.items():
        output_path = output_compressed_paths[url_type]
        if output_path.exists():
            log.info("file %s already exists, ignoring", output_path)
            continue

        log.info("downloading %r into %s", url, output_path)
        async with ctx.session.get(url) as resp:
            assert resp.status == 200

            total_length = int(resp.headers["content-length"])
            downloaded_bytes = 0
            download_ratio = 0

            log.info("to download %d bytes", total_length)

            with tempfile.TemporaryFile() as temp_fd:
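                # Stream the response body in 8 KiB chunks into a temporary
                # file, logging progress whenever the rounded percentage changes.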
                async for chunk in resp.content.iter_chunked(8192):
                    temp_fd.write(chunk)
                    downloaded_bytes += len(chunk)
                    new_download_ratio = round((downloaded_bytes / total_length) * 100)
                    if new_download_ratio != download_ratio:
                        log.info("download at %d%%", new_download_ratio)
                        download_ratio = new_download_ratio

                temp_fd.seek(0)

                # write to output
                log.info("copying temp to output")
                with output_path.open(mode="wb") as output_fd:
                    shutil.copyfileobj(temp_fd, output_fd)
    # decompress
    for url_type, _url in urls.items():
        input_path = output_compressed_paths[url_type]
        output_path = output_uncompressed_paths[url_type]

        if output_path.exists():
            log.info("decompressed file %s already exists, ignoring", output_path)
            continue

        log.info("decompressing %s into %s", input_path.name, output_path.name)
        with gzip.open(input_path, "rb") as in_fd:
            with output_path.open(mode="wb") as out_fd:
                shutil.copyfileobj(in_fd, out_fd)
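
    # Create the schema up front; STRICT tables (sqlite 3.37+) enforce the
    # declared column types on insert.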
    # now that everything's downloaded, compile the db
    await ctx.db.executescript(
        """
        CREATE TABLE IF NOT EXISTS db_version (
            version int primary key
        ) STRICT;

        CREATE TABLE IF NOT EXISTS tags (
            id int primary key,
            name text unique,
            category int not null,
            post_count int not null
        ) STRICT;

        CREATE TABLE IF NOT EXISTS posts (
            id int primary key,
            uploader_id int not null,
            created_at int not null,
            md5 text unique not null,
            source text not null,
            rating text not null,
            tag_string text not null,
            is_deleted int not null,
            is_pending int not null,
            is_flagged int not null,
            score int not null,
            up_score int not null,
            down_score int not null,
            is_rating_locked int not null
        ) STRICT;
        """
    )
    await ctx.db.commit()
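
    # First pass over the CSV counts data rows so progress can be reported as
    # a percentage; the file is then rewound and parsed for real.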
    with output_uncompressed_paths["tags"].open(
        mode="r", encoding="utf-8"
    ) as tags_csv_fd:
        line_count = 0
        for line in tags_csv_fd:
            line_count += 1

        line_count -= 1  # remove header

        log.info("%d tags to import", line_count)
        tags_csv_fd.seek(0)
        tags_reader = csv.reader(tags_csv_fd)

        assert len(next(tags_reader)) == 4
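
        # Row-by-row inserts run inside a single implicit transaction; the one
        # commit after the loop makes the whole import atomic.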
        processed_count = 0
        processed_ratio = 0

        for row in tags_reader:
            tag = Tag(int(row[0]), row[1], int(row[2]), int(row[3]))
            await ctx.db.execute(
                "insert into tags (id, name, category, post_count) values (?, ?, ?, ?)",
                (tag.id, tag.name, tag.category, tag.post_count),
            )
            processed_count += 1
            new_processed_ratio = round((processed_count / line_count) * 100)
            if new_processed_ratio != processed_ratio:
                log.info("processed at %d%%", new_processed_ratio)
                processed_ratio = new_processed_ratio

        log.info("done")

    await ctx.db.commit()


async def main():
    wanted_date = sys.argv[1]
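    # the date is interpolated straight into the export filenames below
    # (presumably an ISO date such as 2022-01-01, matching e621's dump names)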
    urls = (
        f"https://e621.net/db_export/tags-{wanted_date}.csv.gz",
        f"https://e621.net/db_export/posts-{wanted_date}.csv.gz",
    )

    async with aiohttp.ClientSession() as session:
        ctx = Context(session)
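
        # download and decompress both exports before any database work starts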
        for url in urls:
            parsed = urlparse(url)
            parsed_path = Path(parsed.path)
            output_path = Path.cwd() / parsed_path.name
            if output_path.exists():
                log.info("file %s already exists, ignoring", output_path)
                continue

            log.info("downloading %r into %s", url, output_path)
            async with ctx.session.get(url) as resp:
                assert resp.status == 200

                total_length = int(resp.headers["content-length"])
                downloaded_bytes = 0
                download_ratio = 0

                log.info("to download %d bytes", total_length)

                with tempfile.TemporaryFile() as temp_fd:
                    async for chunk in resp.content.iter_chunked(8192):
                        temp_fd.write(chunk)
                        downloaded_bytes += len(chunk)
                        new_download_ratio = round(
                            (downloaded_bytes / total_length) * 100
                        )
                        if new_download_ratio != download_ratio:
                            log.info("download at %d%%", new_download_ratio)
                            download_ratio = new_download_ratio

                    temp_fd.seek(0)

                    # write to output
                    log.info("copying temp to output")
                    with output_path.open(mode="wb") as output_fd:
                        shutil.copyfileobj(temp_fd, output_fd)
        # decompress
        for url in urls:
            parsed = urlparse(url)
            parsed_path = Path(parsed.path)
            input_path = Path.cwd() / parsed_path.name
            original_name, original_extension, _gz = parsed_path.name.split(".")
            output_path = Path.cwd() / f"{original_name}.{original_extension}"
            if output_path.exists():
                log.info("decompressed file %s already exists, ignoring", output_path)
                continue

            log.info("decompressing %s into %s", input_path.name, output_path.name)
            with gzip.open(input_path, "rb") as in_fd:
                with output_path.open(mode="wb") as out_fd:
                    shutil.copyfileobj(in_fd, out_fd)

        # now that everything's downloaded, compile the db
        async with aiosqlite.connect("./e621.db") as db:
            ctx = Context(session, db)
            await main_with_ctx(ctx, wanted_date)
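            # the async context manager closes the connection once
            # main_with_ctx returns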


if __name__ == "__main__":