diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..3bb011e
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,5 @@
+*.csv.gz
+*.csv
+*.db
+env/
+authfile.example
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..e456989
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,8 @@
+FROM python:3.10-alpine
+RUN apk add sqlite
+ADD . ./e621_api_cloner
+RUN pip3 install -Ur ./e621_api_cloner/requirements.txt
+RUN touch ./e621_api_cloner/e621.db
+EXPOSE 1337
+WORKDIR /e621_api_cloner
+CMD ["hypercorn", "--access-log", "-", "-b", "0.0.0.0:1337", "e621_api_cloner:app"]
diff --git a/authfile.example b/authfile.example
new file mode 100644
index 0000000..2aafb8f
--- /dev/null
+++ b/authfile.example
@@ -0,0 +1,2 @@
+5a8649fb426e13c080e118332f44b3adb70cc9c677fba58c3d018313c4b0ad67 my silly tool account 1
+504054a7c35af520ab0ddd14e1ad257633fe75ba8601915541abba6cec1a81f7 my silly tool account 2
diff --git a/build_database.py b/build_database.py
index 4136be5..8eb933f 100644
--- a/build_database.py
+++ b/build_database.py
@@ -196,6 +196,9 @@ async def main_with_ctx(ctx, wanted_date):
         await ctx.db.commit()
 
     log.info("going to process posts")
+    post_count_rows = await ctx.db.execute_fetchall("select count(*) from posts")
+    post_count = post_count_rows[0][0]
+    log.info("already have %d posts", post_count)
 
     with output_uncompressed_paths["posts"].open(
         mode="r", encoding="utf-8"
@@ -207,80 +210,87 @@ async def main_with_ctx(ctx, wanted_date):
         line_count -= 1  # remove header
         log.info("%d posts to import", line_count)
 
-        posts_csv_fd.seek(0)
-        posts_reader = csv.DictReader(posts_csv_fd)
+        if line_count == post_count:
+            log.info("already imported everything, skipping")
+        else:
+            posts_csv_fd.seek(0)
+            posts_reader = csv.DictReader(posts_csv_fd)
 
-        processed_count = 0
-        processed_ratio = 0.0
+            processed_count = 0
+            processed_ratio = 0.0
 
-        for row in posts_reader:
-            created_at_str = row["created_at"]
-            created_at = datetime.strptime(
-                created_at_str[: created_at_str.find(".")], "%Y-%m-%d %H:%M:%S"
-            )
+            for row in posts_reader:
+                created_at_str = row["created_at"]
+                created_at = datetime.strptime(
+                    created_at_str[: created_at_str.find(".")], "%Y-%m-%d %H:%M:%S"
+                )
 
-            post = Post(
-                id=int(row["id"]),
-                uploader_id=int(row["uploader_id"]),
-                created_at=int(created_at.timestamp()),
-                md5=row["md5"],
-                source=row["source"],
-                rating=row["rating"],
-                tag_string=row["tag_string"],
-                is_deleted=e621_bool(row["is_deleted"]),
-                is_pending=e621_bool(row["is_pending"]),
-                is_flagged=e621_bool(row["is_flagged"]),
-                score=int(row["score"]),
-                up_score=int(row["up_score"]),
-                down_score=int(row["down_score"]),
-                is_rating_locked=e621_bool(row["is_rating_locked"]),
-            )
+                post = Post(
+                    id=int(row["id"]),
+                    uploader_id=int(row["uploader_id"]),
+                    created_at=int(created_at.timestamp()),
+                    md5=row["md5"],
+                    source=row["source"],
+                    rating=row["rating"],
+                    tag_string=row["tag_string"],
+                    is_deleted=e621_bool(row["is_deleted"]),
+                    is_pending=e621_bool(row["is_pending"]),
+                    is_flagged=e621_bool(row["is_flagged"]),
+                    score=int(row["score"]),
+                    up_score=int(row["up_score"]),
+                    down_score=int(row["down_score"]),
+                    is_rating_locked=e621_bool(row["is_rating_locked"]),
+                )
 
-            await ctx.db.execute(
-                """
-                insert into posts (
-                    id,
-                    uploader_id,
-                    created_at,
-                    md5,
-                    source,
-                    rating,
-                    tag_string,
-                    is_deleted,
-                    is_pending,
-                    is_flagged,
-                    score,
-                    up_score,
-                    down_score,
-                    is_rating_locked
-                ) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
-                """,
-                (
-                    post.id,
-                    post.uploader_id,
-                    post.created_at,
-                    post.md5,
-                    post.source,
-                    post.rating,
-                    post.tag_string,
-                    post.is_deleted,
-                    post.is_pending,
-                    post.is_flagged,
-                    post.score,
-                    post.up_score,
-                    post.down_score,
-                    post.is_rating_locked,
-                ),
-            )
-            processed_count += 1
-            new_processed_ratio = round((processed_count / line_count) * 100, 2)
-            if str(new_processed_ratio) != str(processed_ratio):
-                log.info("posts processed at %.2f%%", processed_ratio)
-                processed_ratio = new_processed_ratio
+                await ctx.db.execute(
+                    """
+                    insert into posts (
+                        id,
+                        uploader_id,
+                        created_at,
+                        md5,
+                        source,
+                        rating,
+                        tag_string,
+                        is_deleted,
+                        is_pending,
+                        is_flagged,
+                        score,
+                        up_score,
+                        down_score,
+                        is_rating_locked
+                    ) values (?,?,?,?,?,?,?,?,?,?,?,?,?,?)
+                    """,
+                    (
+                        post.id,
+                        post.uploader_id,
+                        post.created_at,
+                        post.md5,
+                        post.source,
+                        post.rating,
+                        post.tag_string,
+                        post.is_deleted,
+                        post.is_pending,
+                        post.is_flagged,
+                        post.score,
+                        post.up_score,
+                        post.down_score,
+                        post.is_rating_locked,
+                    ),
+                )
+                processed_count += 1
+                new_processed_ratio = round((processed_count / line_count) * 100, 2)
+                if str(new_processed_ratio) != str(processed_ratio):
+                    log.info("posts processed at %.2f%%", processed_ratio)
+                    processed_ratio = new_processed_ratio
 
-        log.info("posts done")
+            log.info("posts done")
 
-        await ctx.db.commit()
+            await ctx.db.commit()
+
+    log.info("vacuuming db...")
+    await ctx.db.execute("vacuum")
+    log.info("database built")
 
 
 async def main():
diff --git a/e621_api_cloner.py b/e621_api_cloner.py
index 5c5749b..79b8c09 100644
--- a/e621_api_cloner.py
+++ b/e621_api_cloner.py
@@ -7,6 +7,7 @@ import sqlite3
 import sys
 import os
 import enum
+from pathlib import Path
 from datetime import datetime
 from dataclasses import dataclass, asdict
 from hypercorn.asyncio import serve, Config
@@ -19,13 +20,32 @@ app = Quart(__name__)
 
 @app.before_serving
 async def app_before_serving():
+    logging.basicConfig(
+        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
+    )
+
     db_path = os.environ["DBPATH"]
     app.db = await aiosqlite.connect(db_path)
     app.db.row_factory = sqlite3.Row
 
+    # contains api keys
+    maybe_authfile_path = os.environ.get("AUTHFILE")
+    app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None
+    if app.maybe_authfile:
+        app.apikeys = {}
+        log.info("loading auth with api keys")
+        with app.maybe_authfile.open(mode="r") as fd:
+            for line in fd:
+                api_key, *user_name = line.split(" ")
+                app.apikeys[api_key] = (" ".join(user_name)).strip()
 
 @app.after_serving
 async def app_after_serving():
+    log.info("possibly optimizing database")
+    await app.db.execute("PRAGMA analysis_limit=400")
+    await app.db.execute("PRAGMA optimize")
+
+    log.info("closing connection")
     await app.db.close()
 
 
@@ -96,8 +116,34 @@ class TagCategory(enum.IntEnum):
         }[self]
 
 
+async def maybe_do_authentication():
+    if not app.maybe_authfile:
+        return None
+
+    auth_line = request.headers.get("authorization")
+    if not auth_line:
+        return "authorization header required", 400
+
+    auth_type, auth_data = auth_line.split(" ")
+    if auth_type != "Bearer":
+        log.warn("invalid auth type")
+        return "invalid auth token type (must be Bearer)", 400
+
+    auth_name = app.apikeys.get(auth_data)
+    if auth_name is None:
+        log.warn("invalid auth value")
+        return "invalid auth token (unknown api key)", 400
+
+    log.info("logged in as %r", auth_name)
+    return None
+
+
 @app.route("/posts.json")
 async def posts_json():
+    res = await maybe_do_authentication()
+    if res:
+        return res
+
     tag_str = request.args["tags"]
     tags = tag_str.split(" ")
     if len(tags) != 1:
@@ -130,12 +176,3 @@ async def posts_json():
             post_json["tags"][category_str].append(tag)
 
     return {"posts": [post_json]}
-
-
-if __name__ == "__main__":
-    logging.basicConfig(
-        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
-    )
-    config = Config()
-    config.bind = ["0.0.0.0:1334"]
-    asyncio.run(serve(app, config))
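For reference, a client-side sketch of the new authentication flow follows; it is not part of the patch. It assumes the container built from the Dockerfile above is reachable on localhost:1337 and was started with AUTHFILE set, so /posts.json now requires a Bearer api key. The key below is the first placeholder from authfile.example, and the tag value is hypothetical.

# query_clone.py - minimal client sketch (not part of the patch above).
# Assumes the server is listening on localhost:1337 with AUTHFILE set.
import json
import urllib.parse
import urllib.request

# Placeholder key copied from authfile.example; substitute a real key
# from your own authfile.
API_KEY = "5a8649fb426e13c080e118332f44b3adb70cc9c677fba58c3d018313c4b0ad67"

# posts_json() rejects queries with more than one tag, so send exactly one;
# "some_tag" is a hypothetical value.
query = urllib.parse.urlencode({"tags": "some_tag"})
req = urllib.request.Request(
    f"http://localhost:1337/posts.json?{query}",
    headers={"Authorization": f"Bearer {API_KEY}"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["posts"])

If the server was started without AUTHFILE, maybe_do_authentication() returns None immediately and the Authorization header is ignored; a missing or unknown key otherwise yields a 400, which urlopen surfaces as an HTTPError.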