#!/usr/bin/env python3
"""Serve a read-only ``/posts.json`` endpoint backed by a local SQLite database.

Usage: ``<script> <path-to-sqlite-db>``. The server binds to 0.0.0.0:1334;
set ``DEBUG`` to a non-empty value in the environment for debug logging.
"""
import asyncio
import enum
import logging
import os
import sqlite3
import sys
from dataclasses import asdict, dataclass
from datetime import datetime, timezone

import aiosqlite
from hypercorn.asyncio import Config, serve
from quart import Quart, jsonify, request

log = logging.getLogger(__name__)
app = Quart(__name__)


@app.before_serving
async def app_before_serving():
    """Open the database named by the first CLI argument and attach it to the app."""
    db_path = sys.argv[1]
    app.db = await aiosqlite.connect(db_path)
    # Name-addressable rows, so a posts row can be splatted into Post(**row).
    app.db.row_factory = sqlite3.Row


@app.after_serving
async def app_after_serving():
    """Close the shared database connection when the server shuts down."""
    await app.db.close()


@dataclass
class Tag:
    id: int
    name: str
    category: int
    post_count: int


@dataclass
class Post:
    # Field names mirror the columns of the ``posts`` table, since instances
    # are built with Post(**row) from ``select * from posts``.
    id: int
    uploader_id: int
    created_at: int  # Unix timestamp (seconds)
    md5: str
    source: str
    rating: str
    tag_string: str  # space-separated tag names
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return a JSON-ready dict for this post.

        ``created_at`` is rendered as a UTC ISO-8601 string. ``score``,
        ``up_score`` and ``down_score`` are nested under ``score``. The
        ``is_*`` columns are nested under ``flags``.
        """
        post = asdict(self)
        # Timezone-aware so the output does not depend on the server's local
        # timezone and carries an explicit offset.
        post["created_at"] = datetime.fromtimestamp(
            self.created_at, tz=timezone.utc
        ).isoformat()
        post["score"] = {
            "up": post.pop("up_score"),
            "down": post.pop("down_score"),
            "total": self.score,
        }
        post["flags"] = {
            "pending": post.pop("is_pending"),
            "flagged": post.pop("is_flagged"),
            "deleted": post.pop("is_deleted"),
            "rating_locked": post.pop("is_rating_locked"),
        }
        return post


class TagCategory(enum.IntEnum):
    # Integer values as stored in the ``category`` column of the ``tags`` table.
    GENERAL = 0
    ARTIST = 1
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the lowercase category name used as a key in API output."""
        return self.name.lower()


async def _tags_by_category(tag_string):
    """Group a space-separated tag string into ``{category_name: [tag, ...]}``.

    Tags missing from the ``tags`` table, or whose stored category is not a
    known TagCategory, are logged and omitted rather than failing the request.
    """
    grouped = {}
    # split() (not split(" ")) so repeated/trailing spaces or an empty string
    # never produce an empty tag name.
    for tag in (tag_string or "").split():
        tag_rows = await app.db.execute_fetchall(
            "select category from tags where name = ?", (tag,)
        )
        if not tag_rows:
            log.warning("tag %r not found in tags table; skipping", tag)
            continue
        raw_category = tag_rows[0][0]
        try:
            category = TagCategory(raw_category)
        except ValueError:
            log.warning(
                "tag %r has unknown category %r; skipping", tag, raw_category
            )
            continue
        grouped.setdefault(category.to_string(), []).append(tag)
    return grouped


@app.route("/posts.json")
async def posts_json():
    """Look up one post by MD5 and return it with its tags grouped by category.

    Only queries of the form ``?tags=md5:<hash>`` are supported. Anything else
    gets a plain-text 400 response. An unknown hash yields ``{"posts": []}``.
    """
    # A missing parameter is treated like an empty query (-> 400 below).
    tags = request.args.get("tags", "").split()
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400

    md5_tag = tags[0]
    md5_prefix = "md5:"
    if not md5_tag.startswith(md5_prefix):
        return "unsupported query (only md5 for now)", 400
    # Slice rather than split(":") so extra colons cannot break unpacking.
    md5_hexencoded_hash = md5_tag[len(md5_prefix):]

    rows = await app.db.execute_fetchall(
        "select * from posts where md5 = ?", (md5_hexencoded_hash,)
    )
    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    post_json["tags"] = await _tags_by_category(post.tag_string)
    return {"posts": [post_json]}


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )
    # Fail fast with a usage message instead of an IndexError raised later
    # inside app_before_serving.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <database-path>")

    config = Config()
    config.bind = ["0.0.0.0:1334"]
    asyncio.run(serve(app, config))