e621_api_cloner/e621_api_cloner.py
2022-08-30 21:10:13 -03:00

178 lines
4.5 KiB
Python

#!/usr/bin/env python3
import asyncio
import logging
import aiosqlite
import sqlite3
import sys
import os
import enum
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from hypercorn.asyncio import serve, Config
from quart import Quart, request, jsonify
# Module-level logger. Handlers/level come from the logging.basicConfig call
# made at startup in app_before_serving.
log = logging.getLogger(__name__)
# The Quart application. app_before_serving attaches the database connection
# (app.db), the optional auth-file path (app.maybe_authfile) and, when that is
# set, the API-key table (app.apikeys) to this object.
app = Quart(__name__)
@app.before_serving
async def app_before_serving():
    """Configure logging, open the database and load API keys before serving.

    Environment variables:
        DEBUG: when set to a non-empty value, log at DEBUG instead of INFO.
        DBPATH: path to the SQLite database (required; KeyError if unset).
        AUTHFILE: optional path to an API-key file. Each non-blank line is
            ``<api key> <user name>``. When set, maybe_do_authentication
            requires a known bearer token on every request.
    """
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )

    db_path = os.environ["DBPATH"]
    app.db = await aiosqlite.connect(db_path)
    # sqlite3.Row allows lookup by column name, so a row can be splatted
    # straight into a dataclass (Post(**row)).
    app.db.row_factory = sqlite3.Row

    # The auth file contains api keys.
    maybe_authfile_path = os.environ.get("AUTHFILE")
    app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None
    if app.maybe_authfile:
        app.apikeys = {}
        log.info("loading auth with api keys")
        with app.maybe_authfile.open(mode="r", encoding="utf-8") as fd:
            for line in fd:
                # Strip before splitting. Otherwise a line with only a key
                # keeps its trailing newline in the key and can never match.
                api_key, _, user_name = line.strip().partition(" ")
                if not api_key:
                    continue  # blank line
                app.apikeys[api_key] = user_name.strip()
@app.after_serving
async def app_after_serving():
    """Ask SQLite to optimize its statistics, then close the database.

    The connection is closed even if an optimization PRAGMA fails. The
    original error still propagates after the close.
    """
    try:
        log.info("possibly optimizing database")
        # Cap how much work the ANALYZE run by "PRAGMA optimize" may do,
        # so that shutdown stays quick.
        await app.db.execute("PRAGMA analysis_limit=400")
        # SQLite decides for itself whether anything needs re-analysing,
        # hence "possibly" in the log line above.
        await app.db.execute("PRAGMA optimize")
    finally:
        log.info("closing connection")
        await app.db.close()
@dataclass
class Tag:
    """A tag record: its identifier, name, numeric category and usage count.

    ``category`` holds the raw integer; ``TagCategory`` maps it to a label.
    """

    # NOTE(review): field names presumably mirror the tags table's columns
    # (the table has at least ``name`` and ``category``); confirm against the schema.
    id: int
    name: str
    category: int
    post_count: int
@dataclass
class Post:
    """One row of the ``posts`` table.

    Instances are built with ``Post(**row)``, so field names must match the
    table's column names.
    """

    id: int
    uploader_id: int
    created_at: int  # Unix timestamp (seconds), as fed to fromtimestamp
    md5: str
    source: str
    rating: str
    tag_string: str  # space-separated tag names
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return this post as a JSON-ready dict.

        Starts from all fields, then:
        - ``created_at`` becomes an ISO-8601 string in UTC with an explicit offset;
        - ``score`` becomes ``{"up", "down", "total"}``, and ``up_score`` and
          ``down_score`` are removed;
        - the ``is_*`` fields move under ``flags`` as ``pending``, ``flagged``,
          ``deleted`` and ``rating_locked``.
        """
        # Imported here so this class does not depend on the module header.
        from datetime import timezone

        post = asdict(self)
        # A Unix timestamp denotes an absolute instant. Render it in UTC with
        # an offset; the naive local-time form varied with the server's TZ
        # and gave clients no way to recover the real time.
        post["created_at"] = datetime.fromtimestamp(
            self.created_at, tz=timezone.utc
        ).isoformat()
        post["score"] = {
            "up": post.pop("up_score"),
            "down": post.pop("down_score"),
            "total": self.score,
        }
        post["flags"] = {
            "pending": post.pop("is_pending"),
            "flagged": post.pop("is_flagged"),
            "deleted": post.pop("is_deleted"),
            "rating_locked": post.pop("is_rating_locked"),
        }
        return post
class TagCategory(enum.IntEnum):
    """Integer tag categories, as stored in the ``tags.category`` column.

    Value 2 is not a member, so ``TagCategory(2)`` raises ValueError.
    """

    GENERAL = 0
    ARTIST = 1
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the lowercase label used to group tags in API output."""
        # Every member's label is exactly its name in lowercase.
        return self.name.lower()
async def maybe_do_authentication():
    """Validate the request's bearer token when an API-key file is configured.

    Returns:
        None when the request may proceed: auth is disabled, or the token is a
        known API key. Otherwise a ``(message, 400)`` tuple that the caller
        returns directly as the response.
    """
    if not app.maybe_authfile:
        return None

    auth_line = request.headers.get("authorization")
    if not auth_line:
        return "authorization header required", 400

    # partition never raises. A 2-way unpack of split(" ") threw ValueError
    # (an HTTP 500) for "Bearer" alone or for a token containing spaces.
    auth_type, _, auth_data = auth_line.partition(" ")
    # Authentication scheme names are case-insensitive (RFC 7235, section 2.1).
    if auth_type.lower() != "bearer":
        log.warning("invalid auth type")
        return "invalid auth token type (must be Bearer)", 400

    token = auth_data.strip()
    # An empty token is rejected outright, so it can never match an
    # accidental "" key from the auth file.
    auth_name = app.apikeys.get(token) if token else None
    if auth_name is None:
        log.warning("invalid auth value")
        return "invalid auth token (unknown api key)", 400

    log.info("logged in as %r", auth_name)
    return None
@app.route("/posts.json")
async def posts_json():
    """Serve ``/posts.json`` for a query of exactly one ``md5:<hash>`` tag.

    Returns:
        ``{"posts": [post]}`` when a post has that md5, with its tags grouped
        by category label under ``post["tags"]``. ``{"posts": []}`` when no
        post matches. A ``(message, 400)`` tuple for failed authentication or
        an unsupported or missing query.
    """
    res = await maybe_do_authentication()
    if res:
        return res

    tag_str = request.args.get("tags")
    if tag_str is None:
        return "tags query parameter required", 400

    # split() with no argument ignores leading, trailing and repeated whitespace.
    tags = tag_str.split()
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400

    md5_tag = tags[0]
    if not md5_tag.startswith("md5:"):
        return "unsupported query (only md5 for now)", 400
    # Take everything after the prefix. split(":") raised ValueError when the
    # value itself contained a colon.
    md5_hexencoded_hash = md5_tag[len("md5:"):]

    rows = await app.db.execute_fetchall(
        "select * from posts where md5 = ?", (md5_hexencoded_hash,)
    )
    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    tags_by_category = {}
    # split() also turns an empty tag_string into no tags, rather than [""].
    for tag in post.tag_string.split():
        category = await _tag_category(tag)
        tags_by_category.setdefault(category.to_string(), []).append(tag)
    post_json["tags"] = tags_by_category
    return {"posts": [post_json]}


async def _tag_category(tag):
    """Return the TagCategory of ``tag`` as recorded in the ``tags`` table.

    Falls back to TagCategory.GENERAL, with a warning, when the tag is absent
    from the table or its stored category is not a TagCategory member.
    Previously either case raised and turned the whole response into a 500.
    """
    # NOTE(review): GENERAL as the fallback is a judgement call; confirm it
    # suits API clients.
    tag_rows = await app.db.execute_fetchall(
        "select category from tags where name = ?", (tag,)
    )
    if not tag_rows:
        log.warning("tag %r not found in tags table, using general", tag)
        return TagCategory.GENERAL
    raw_category = tag_rows[0][0]
    try:
        return TagCategory(raw_category)
    except ValueError:
        log.warning(
            "tag %r has unknown category %r, using general", tag, raw_category
        )
        return TagCategory.GENERAL