e621_api_cloner/e621_api_cloner.py

194 lines
5 KiB
Python

#!/usr/bin/env python3
import asyncio
import logging
import aiosqlite
import sqlite3
import sys
import os
import enum
from pathlib import Path
from datetime import datetime
from dataclasses import dataclass, asdict
from hypercorn.asyncio import serve, Config
from quart import Quart, request, jsonify
# Module-level logger; logging itself is configured in app_before_serving.
log = logging.getLogger(__name__)
# The Quart application. The startup hook attaches the database connection
# (app.db) and, when an auth file is configured, the API key table (app.apikeys).
app = Quart(__name__)
@app.before_serving
async def app_before_serving():
    """Configure logging, open the database and load optional API keys.

    Environment variables read:
        DEBUG: when set to a non-empty value, log at DEBUG instead of INFO.
        DBPATH: required; path handed to ``aiosqlite.connect``.
        AUTHFILE: optional; path of a text file with one ``<api key> <user
            name>`` entry per line. When set, ``app.apikeys`` maps each key
            to its user name.

    Raises:
        KeyError: if DBPATH is not set.
    """
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )

    db_path = os.environ["DBPATH"]
    app.db = await aiosqlite.connect(db_path)
    # Rows become sqlite3.Row objects, addressable by column name.
    app.db.row_factory = sqlite3.Row

    # contains api keys
    maybe_authfile_path = os.environ.get("AUTHFILE")
    app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None
    if app.maybe_authfile:
        app.apikeys = {}
        log.info("loading auth with api keys")
        with app.maybe_authfile.open(mode="r") as fd:
            for line in fd:
                # Strip the line ending before splitting so that a key-only
                # line does not keep a trailing newline inside the key, and
                # skip blank lines instead of registering them as keys.
                entry = line.strip()
                if not entry:
                    continue
                # Everything after the first space is the user name, which
                # may itself contain spaces.
                api_key, _, user_name = entry.partition(" ")
                app.apikeys[api_key] = user_name.strip()
@app.after_serving
async def app_after_serving():
    """Run SQLite's optimize pragmas, then close the database connection."""
    log.info("possibly optimizing database")
    # The analysis limit is set first so that it applies to the optimize run.
    for pragma in ("PRAGMA analysis_limit=400", "PRAGMA optimize"):
        await app.db.execute(pragma)

    log.info("closing connection")
    await app.db.close()
@dataclass
class Tag:
    """A single tag record.

    NOTE(review): presumably mirrors one row of the ``tags`` table; nothing
    in this part of the file constructs it, so confirm against its users.
    """

    # Numeric identifier of the tag.
    id: int
    # Tag text.
    name: str
    # Numeric category; presumably one of the TagCategory values - confirm.
    category: int
    # Number of posts carrying this tag.
    post_count: int
@dataclass
class Post:
    """A post record in flat form, convertible to a nested JSON layout."""

    id: int
    uploader_id: int
    # Passed to datetime.fromtimestamp, i.e. a POSIX timestamp in seconds.
    created_at: int
    md5: str
    source: str
    rating: str
    # Tag names separated by single spaces.
    tag_string: str
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return this post as a dict with nested score, file and flags.

        The flat ``up_score``/``down_score``/``score`` fields fold into
        ``score``, ``md5`` moves under ``file``, and the ``is_*`` fields move
        under ``flags``. ``created_at`` becomes an ISO 8601 string derived
        from the timestamp in the process's local timezone.
        """
        payload = asdict(self)

        # Take the flat vote counters out before nesting them under "score".
        up_votes = payload.pop("up_score")
        down_votes = payload.pop("down_score")

        payload["created_at"] = datetime.fromtimestamp(self.created_at).isoformat()
        payload["score"] = {
            "up": up_votes,
            "down": down_votes,
            "total": self.score,
        }
        payload["file"] = {"md5": payload.pop("md5")}
        # Each flag "x" is sourced from the flat field "is_x".
        payload["flags"] = {
            flag: payload.pop(f"is_{flag}")
            for flag in ("pending", "flagged", "deleted", "rating_locked")
        }
        return payload
class TagCategory(enum.IntEnum):
    """Numeric tag categories, each with a lowercase string name."""

    GENERAL = 0
    ARTIST = 1
    # Note that 2 is not assigned to any category here.
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the lowercase string name of this category."""
        return _TAG_CATEGORY_NAMES[self]


# String name of every category. METADATA is the only one whose string
# ("meta") differs from its lowercased member name.
_TAG_CATEGORY_NAMES = {
    TagCategory.GENERAL: "general",
    TagCategory.ARTIST: "artist",
    TagCategory.COPYRIGHT: "copyright",
    TagCategory.CHARACTER: "character",
    TagCategory.SPECIES: "species",
    TagCategory.DEPRECATED: "deprecated",
    TagCategory.METADATA: "meta",
    TagCategory.LORE: "lore",
}
async def maybe_do_authentication():
    """Check the request's bearer API key when an auth file is configured.

    Returns:
        None when authentication is disabled (no auth file) or the supplied
        key is known; otherwise a ``(message, 400)`` tuple that the calling
        route returns as its response.
    """
    if not app.maybe_authfile:
        return None

    auth_line = request.headers.get("authorization")
    if not auth_line:
        return "authorization header required", 400

    # Split on any run of whitespace and check the shape explicitly, so a
    # header without exactly two parts gets a 400 rather than an unhandled
    # ValueError from tuple unpacking.
    auth_parts = auth_line.split()
    if len(auth_parts) != 2:
        log.warning("malformed authorization header")
        return "invalid authorization header (must be 'Bearer <api key>')", 400

    auth_type, auth_data = auth_parts
    if auth_type != "Bearer":
        log.warning("invalid auth type")
        return "invalid auth token type (must be Bearer)", 400

    auth_name = app.apikeys.get(auth_data)
    if auth_name is None:
        log.warning("invalid auth value")
        return "invalid auth token (unknown api key)", 400

    log.info("logged in as %r", auth_name)
    return None
@app.route("/posts.json")
async def posts_json():
    """Look up one post by ``id:<int>`` or ``md5:<hash>``.

    The ``tags`` query parameter must hold exactly one ``a:b`` term whose
    ``a`` is ``id`` or ``md5``; anything else gets a 400 response.

    Returns:
        ``{"posts": []}`` when nothing matches, otherwise ``{"posts": [post]}``
        where ``post`` is ``Post.to_json()`` plus a ``tags`` dict mapping each
        category name to the post's tags in that category. An authentication
        failure or unsupported query returns a ``(message, 400)`` tuple.
    """
    auth_error = await maybe_do_authentication()
    if auth_error:
        return auth_error

    tag_str = request.args["tags"]
    tags = tag_str.split(" ")
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400

    tag_parts = tags[0].split(":")
    if len(tag_parts) != 2:
        return "unsupported query (tag must be a:b format)", 400

    # Named query_type rather than `type` to avoid shadowing the builtin.
    query_type, value = tag_parts
    if query_type == "md5":
        rows = await app.db.execute_fetchall(
            "select * from posts where md5 = ?", (value,)
        )
    elif query_type == "id":
        try:
            post_id = int(value)
        except ValueError:
            return "id:x x must be int", 400
        rows = await app.db.execute_fetchall(
            "select * from posts where id = ?", (post_id,)
        )
    else:
        return (
            "unsupported query (tag must be a:b format, and a must be id or md5)",
            400,
        )

    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    # Every category is present in the output, even when it has no tags.
    post_json["tags"] = {category.to_string(): [] for category in TagCategory}

    # split() without a separator yields no empty strings, so an empty
    # tag_string produces no lookups at all.
    for tag in post.tag_string.split():
        tag_rows = await app.db.execute_fetchall(
            "select category from tags where name = ?", (tag,)
        )
        if not tag_rows:
            # A tag missing from the tags table must not fail the whole
            # request; it is logged and left out of the category lists.
            log.warning("post %r has tag %r that is not in the tags table", post.id, tag)
            continue
        try:
            category = TagCategory(tag_rows[0][0])
        except ValueError:
            log.warning(
                "tag %r has unknown category %r; skipping", tag, tag_rows[0][0]
            )
            continue
        post_json["tags"][category.to_string()].append(tag)

    return {"posts": [post_json]}