diff --git a/e621_api_cloner.py b/e621_api_cloner.py new file mode 100644 index 0000000..62971c5 --- /dev/null +++ b/e621_api_cloner.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python3 + +import asyncio +import logging +import aiosqlite +import sqlite3 +import sys +import os +import enum +from datetime import datetime +from dataclasses import dataclass, asdict +from hypercorn.asyncio import serve, Config +from quart import Quart, request, jsonify + + +log = logging.getLogger(__name__) +app = Quart(__name__) + + +@app.before_serving +async def app_before_serving(): + db_path = sys.argv[1] + app.db = await aiosqlite.connect(db_path) + app.db.row_factory = sqlite3.Row + + +@app.after_serving +async def app_after_serving(): + await app.db.close() + + +@dataclass +class Tag: + id: int + name: str + category: int + post_count: int + + +@dataclass +class Post: + id: int + uploader_id: int + created_at: int + md5: str + source: str + rating: str + tag_string: str + is_deleted: int + is_pending: int + is_flagged: int + score: int + up_score: int + down_score: int + is_rating_locked: int + + def to_json(self): + post = asdict(self) + post["created_at"] = datetime.fromtimestamp(self.created_at).isoformat() + post["score"] = { + "up": self.up_score, + "down": self.down_score, + "total": self.score, + } + post.pop("up_score") + post.pop("down_score") + post["flags"] = { + "pending": post.pop("is_pending"), + "flagged": post.pop("is_flagged"), + "deleted": post.pop("is_deleted"), + "rating_locked": post.pop("is_rating_locked"), + } + return post + + +class TagCategory(enum.IntEnum): + GENERAL = 0 + ARTIST = 1 + COPYRIGHT = 3 + CHARACTER = 4 + SPECIES = 5 + DEPRECATED = 6 + METADATA = 7 + LORE = 8 + + def to_string(self) -> str: + return { + self.GENERAL: "general", + self.ARTIST: "artist", + self.COPYRIGHT: "copyright", + self.CHARACTER: "character", + self.METADATA: "metadata", + self.DEPRECATED: "deprecated", + self.SPECIES: "species", + self.LORE: "lore", + }[self] + + +@app.route("/posts.json") +async def posts_json(): + tag_str = request.args["tags"] + tags = tag_str.split(" ") + if len(tags) != 1: + return "unsupported query (only 1 tag)", 400 + + md5_tag = tags[0] + if not md5_tag.startswith("md5:"): + return "unsupported query (only md5 for now)", 400 + + md5_prefix, md5_hexencoded_hash = md5_tag.split(":") + assert md5_prefix == "md5" + + rows = await app.db.execute_fetchall( + "select * from posts where md5 = ?", (md5_hexencoded_hash,) + ) + if not rows: + return {"posts": []} + + post = Post(**rows[0]) + post_json = post.to_json() + post_json["tags"] = {} + for tag in post.tag_string.split(" "): + tag_rows = await app.db.execute_fetchall( + "select category from tags where name = ?", (tag,) + ) + category = TagCategory(tag_rows[0][0]) + category_str = category.to_string() + if category_str not in post_json["tags"]: + post_json["tags"][category_str] = [] + post_json["tags"][category_str].append(tag) + + return {"posts": [post_json]} + + +if __name__ == "__main__": + logging.basicConfig( + level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO + ) + config = Config() + config.bind = ["0.0.0.0:1334"] + asyncio.run(serve(app, config))