e621_api_cloner/e621_api_cloner.py

142 lines
3.3 KiB
Python
Raw Normal View History

2022-08-29 03:26:33 +00:00
#!/usr/bin/env python3
import asyncio
import enum
import logging
import os
import sqlite3
import sys
from dataclasses import asdict, dataclass
from datetime import datetime, timezone

import aiosqlite
from hypercorn.asyncio import Config, serve
from quart import Quart, jsonify, request
# Module-level logger. When run as a script, its level is set by
# logging.basicConfig() in the __main__ section (DEBUG if $DEBUG is set).
log = logging.getLogger(__name__)
# The Quart application. The before/after-serving hooks attach an aiosqlite
# connection to it as ``app.db``, which the request handlers query.
app = Quart(__name__)
@app.before_serving
async def app_before_serving():
    """Open the SQLite database once, before the first request is served.

    The database path is taken from the first command-line argument. The
    connection is stored on ``app.db`` with ``sqlite3.Row`` as the row
    factory, so handlers can read columns by name (``Post(**row)``).

    Raises:
        FileNotFoundError: if the path is not an existing file. Without this
            check sqlite would silently create a new, empty database at that
            path, and every request would then fail with "no such table".
    """
    db_path = sys.argv[1]
    if not os.path.isfile(db_path):
        raise FileNotFoundError(f"database file not found: {db_path!r}")
    app.db = await aiosqlite.connect(db_path)
    app.db.row_factory = sqlite3.Row
@app.after_serving
async def app_after_serving():
    """Close the database connection that was opened at server start-up."""
    connection = app.db
    await connection.close()
@dataclass()
class Tag:
    """One row of the ``tags`` table."""

    id: int  # primary key of the tag row
    name: str  # tag name, as it appears in a post's tag_string
    category: int  # raw numeric category id (see TagCategory for known ids)
    post_count: int  # presumably the number of posts using this tag
@dataclass
class Post:
    """One row of the ``posts`` table, convertible to e621's post JSON layout."""

    id: int
    uploader_id: int
    created_at: int  # Unix timestamp in seconds (fed to datetime.fromtimestamp)
    md5: str
    source: str
    rating: str
    tag_string: str  # tag names separated by spaces
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return this post as a JSON-serialisable dict shaped like e621's.

        Starts from every dataclass field and then:

        * ``created_at`` becomes an ISO-8601 string with an explicit UTC
          offset (``+00:00``). A naive ``fromtimestamp()`` would render the
          server's local time with no offset, which clients cannot interpret
          reliably.
        * ``score``/``up_score``/``down_score`` are folded into one
          ``{"up", "down", "total"}`` object.
        * the ``is_*`` columns are folded into a ``flags`` object of
          booleans, because e621 reports flags as JSON ``true``/``false``
          rather than 0/1.
        """
        post = asdict(self)
        post["created_at"] = datetime.fromtimestamp(
            self.created_at, tz=timezone.utc
        ).isoformat()
        # Assigning to the existing "score" key keeps its position in the dict.
        post["score"] = {
            "up": post.pop("up_score"),
            "down": post.pop("down_score"),
            "total": self.score,
        }
        post["flags"] = {
            "pending": bool(post.pop("is_pending")),
            "flagged": bool(post.pop("is_flagged")),
            "deleted": bool(post.pop("is_deleted")),
            "rating_locked": bool(post.pop("is_rating_locked")),
        }
        return post
class TagCategory(enum.IntEnum):
    """Numeric tag categories as stored in the ``tags.category`` column.

    The ids presumably mirror e621's category ids. Id 2 has no member, so
    ``TagCategory(2)`` raises ``ValueError``.
    """

    GENERAL = 0
    ARTIST = 1
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the key e621's ``posts.json`` uses for this category's tags.

        e621 names category 6 ``"invalid"`` and category 7 ``"meta"`` in its
        post responses, so those two keys differ from the member names.
        """
        return {
            self.GENERAL: "general",
            self.ARTIST: "artist",
            self.COPYRIGHT: "copyright",
            self.CHARACTER: "character",
            self.METADATA: "meta",
            self.DEPRECATED: "invalid",
            self.SPECIES: "species",
            self.LORE: "lore",
        }[self]
async def _tags_by_category(tag_string):
    """Group the whitespace-separated tag names in *tag_string* by category.

    Each tag's category id is looked up in the ``tags`` table and mapped to
    its e621 key with ``TagCategory.to_string``. The result maps category key
    -> tag names, in the order the tags appear in *tag_string*; categories
    with no tags are absent.

    Tags missing from the ``tags`` table, or whose category id has no
    ``TagCategory`` member, are logged and skipped. Without this, one bad
    tag would fail the whole request with a 500.
    """
    grouped = {}
    # split() with no argument ignores repeated/leading/trailing whitespace,
    # so an empty tag_string yields no tags instead of one "" tag.
    for tag in tag_string.split():
        tag_rows = await app.db.execute_fetchall(
            "select category from tags where name = ?", (tag,)
        )
        if not tag_rows:
            log.warning("tag %r not found in tags table, skipping", tag)
            continue
        category_id = tag_rows[0][0]
        try:
            category = TagCategory(category_id)
        except ValueError:
            log.warning(
                "tag %r has unknown category %r, skipping", tag, category_id
            )
            continue
        grouped.setdefault(category.to_string(), []).append(tag)
    return grouped


@app.route("/posts.json")
async def posts_json():
    """Serve a minimal subset of e621's ``GET /posts.json``.

    Only a single ``md5:<hex>`` tag query is supported, e.g.
    ``/posts.json?tags=md5:<hash>``. The response is ``{"posts": [...]}``
    with at most one post. That post's ``tags`` field groups its tag names
    by category key.

    Any other query shape gets a plain-text error with HTTP status 400.
    """
    # split() collapses runs of whitespace, so "md5:<hash> " is still one tag.
    tags = request.args["tags"].split()
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400
    md5_tag = tags[0]
    if not md5_tag.startswith("md5:"):
        return "unsupported query (only md5 for now)", 400
    # Take everything after the "md5:" prefix. Unlike unpacking split(":"),
    # slicing cannot raise on a malformed value such as "md5:a:b".
    md5_hexencoded_hash = md5_tag[len("md5:"):]

    rows = await app.db.execute_fetchall(
        "select * from posts where md5 = ?", (md5_hexencoded_hash,)
    )
    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    post_json["tags"] = await _tags_by_category(post.tag_string)
    return {"posts": [post_json]}
if __name__ == "__main__":
    # app_before_serving() reads the database path from sys.argv[1]. Checking
    # for it here prints a usage message instead of an IndexError traceback
    # raised while the server is already starting up.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <path to sqlite database>")
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )
    config = Config()
    # Listen on all interfaces, port 1334.
    config.bind = ["0.0.0.0:1334"]
    asyncio.run(serve(app, config))