e621_api_cloner/e621_api_cloner.py

195 lines
5 KiB
Python
Raw Permalink Normal View History

2022-08-29 03:26:33 +00:00
#!/usr/bin/env python3
import asyncio
import logging
import aiosqlite
import sqlite3
import sys
import os
import enum
2022-08-31 00:09:54 +00:00
from pathlib import Path
2022-08-29 03:26:33 +00:00
from datetime import datetime
from dataclasses import dataclass, asdict
from hypercorn.asyncio import serve, Config
from quart import Quart, request, jsonify
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# The Quart application; the lifecycle hooks and routes below register on it.
app = Quart(__name__)
@app.before_serving
async def app_before_serving():
    """Configure logging, open the SQLite database and load API keys.

    Environment variables read:

    * ``DEBUG`` -- when set to a non-empty value, log at DEBUG level
      instead of INFO.
    * ``DBPATH`` -- required; path passed to ``aiosqlite.connect``.
    * ``AUTHFILE`` -- optional; path to a text file with one
      ``<api_key> <user name>`` entry per line.

    Sets ``app.db``, ``app.maybe_authfile`` and ``app.apikeys``.

    Raises:
        KeyError: if ``DBPATH`` is not set.
    """
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )

    db_path = os.environ["DBPATH"]
    app.db = await aiosqlite.connect(db_path)
    # sqlite3.Row allows rows to be addressed by column name.
    app.db.row_factory = sqlite3.Row

    # contains api keys
    maybe_authfile_path = os.environ.get("AUTHFILE")
    app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None

    # Always define the mapping so it exists even when auth is disabled.
    app.apikeys = {}
    if app.maybe_authfile:
        log.info("loading auth with api keys")
        with app.maybe_authfile.open(mode="r") as fd:
            for line in fd:
                # Strip first: otherwise a key-only line keeps its trailing
                # newline inside the key and blank lines become "\n" keys.
                entry = line.strip()
                if not entry:
                    continue
                api_key, _, user_name = entry.partition(" ")
                app.apikeys[api_key] = user_name.strip()
2022-08-29 03:26:33 +00:00
@app.after_serving
async def app_after_serving():
    """Ask SQLite to optimize the database, then close the connection.

    The connection is closed in a ``finally`` clause so that a failing
    PRAGMA does not leave it open; the original exception still propagates.
    """
    try:
        log.info("possibly optimizing database")
        await app.db.execute("PRAGMA analysis_limit=400")
        await app.db.execute("PRAGMA optimize")
    finally:
        log.info("closing connection")
        await app.db.close()
@dataclass
class Tag:
    """Plain record describing a tag: id, name, category and post count."""

    id: int
    name: str
    # Numeric category; presumably a TagCategory value -- confirm with callers.
    category: int
    post_count: int
@dataclass
class Post:
    """Flat post record with a converter to a nested JSON-ready dict.

    ``created_at`` is handled as a Unix timestamp (it is passed to
    ``datetime.fromtimestamp``). The ``is_*`` fields are integer flags.
    """

    id: int
    uploader_id: int
    created_at: int
    md5: str
    source: str
    rating: str
    tag_string: str
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return this post as a nested dict ready for JSON serialization.

        Differences from the flat fields:

        * ``created_at`` becomes a timezone-aware ISO 8601 string in UTC.
        * ``score`` becomes ``{"up", "down", "total"}``; ``up_score`` and
          ``down_score`` are removed from the top level.
        * ``md5`` moves to ``file.md5``.
        * the ``is_*`` flags move under ``flags`` without the ``is_`` prefix.

        All other fields are copied through unchanged.
        """
        # Local import: the module only imports ``datetime`` from ``datetime``.
        from datetime import timezone

        post = asdict(self)

        # Convert explicitly in UTC. Without ``tz`` the result would be a
        # naive local-time string that varies with the server's timezone.
        created = datetime.fromtimestamp(self.created_at, tz=timezone.utc)
        post["created_at"] = created.isoformat()

        post["score"] = {
            "up": post.pop("up_score"),
            "down": post.pop("down_score"),
            "total": self.score,
        }
        post["file"] = {"md5": post.pop("md5")}
        post["flags"] = {
            "pending": post.pop("is_pending"),
            "flagged": post.pop("is_flagged"),
            "deleted": post.pop("is_deleted"),
            "rating_locked": post.pop("is_rating_locked"),
        }
        return post
class TagCategory(enum.IntEnum):
    """Numeric tag categories and their lowercase string names."""

    GENERAL = 0
    ARTIST = 1
    # Value 2 is not assigned to any member.
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the lowercase string name for this category.

        Every member maps to its own name in lowercase, except ``METADATA``
        which is rendered as ``"meta"``.
        """
        if self is TagCategory.METADATA:
            return "meta"
        return self.name.lower()
2022-08-31 00:09:54 +00:00
async def maybe_do_authentication():
    """Validate the request's bearer token when an auth file is configured.

    Returns:
        ``None`` when the request may proceed, either because
        ``app.maybe_authfile`` is unset or because the presented API key is
        in ``app.apikeys``. Otherwise an ``(error message, 400)`` tuple that
        the route handler returns as the response.
    """
    if not app.maybe_authfile:
        return None

    auth_line = request.headers.get("authorization")
    if not auth_line:
        return "authorization header required", 400

    # partition() cannot raise, unlike unpacking split(" ") into two names,
    # which turned headers such as "Bearer" or "Bearer a b" into a 500.
    auth_type, _, auth_data = auth_line.partition(" ")
    if auth_type != "Bearer":
        log.warning("invalid auth type")
        return "invalid auth token type (must be Bearer)", 400

    auth_data = auth_data.strip()
    # An empty token is never looked up, so it cannot match a stray empty key.
    auth_name = app.apikeys.get(auth_data) if auth_data else None
    if auth_name is None:
        log.warning("invalid auth value")
        return "invalid auth token (unknown api key)", 400

    log.info("logged in as %r", auth_name)
    return None
2022-08-29 03:26:33 +00:00
@app.route("/posts.json")
async def posts_json():
    """Look up a single post by ``id:<int>`` or ``md5:<hash>``.

    The ``tags`` query parameter must hold exactly one tag in ``a:b`` form,
    where ``a`` is ``id`` or ``md5``. Any other query gets a 400 response
    with a plain-text explanation.

    Returns:
        ``{"posts": []}`` when nothing matches, otherwise ``{"posts": [post]}``
        where ``post`` is ``Post.to_json()`` plus a ``tags`` mapping from
        category name to the list of the post's tags in that category.
        Authentication failures return the response produced by
        ``maybe_do_authentication``.
    """
    res = await maybe_do_authentication()
    if res:
        return res

    tag_str = request.args["tags"]
    tags = tag_str.split(" ")
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400

    tag_parts = tags[0].split(":")
    if len(tag_parts) != 2:
        return "unsupported query (tag must be a:b format)", 400

    query_kind, value = tag_parts
    if query_kind == "md5":
        rows = await app.db.execute_fetchall(
            "select * from posts where md5 = ?", (value,)
        )
    elif query_kind == "id":
        try:
            post_id = int(value)
        except ValueError:
            return "id:x x must be int", 400
        rows = await app.db.execute_fetchall(
            "select * from posts where id = ?", (post_id,)
        )
    else:
        return (
            "unsupported query (tag must be a:b format, and a must be id or md5)",
            400,
        )

    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    post_json["tags"] = {category.to_string(): [] for category in TagCategory}

    # split() with no argument yields nothing for an empty tag_string and
    # ignores repeated spaces, so no empty "tag" is ever looked up.
    for tag in post.tag_string.split():
        tag_rows = await app.db.execute_fetchall(
            "select category from tags where name = ?", (tag,)
        )
        if not tag_rows:
            # The tag stays visible in tag_string; it just cannot be
            # categorized, which should not fail the whole request.
            log.warning("post %r: tag %r not found in tags table", post.id, tag)
            continue
        try:
            category = TagCategory(tag_rows[0][0])
        except ValueError:
            log.warning(
                "post %r: tag %r has unknown category %r",
                post.id,
                tag,
                tag_rows[0][0],
            )
            continue
        post_json["tags"][category.to_string()].append(tag)

    return {"posts": [post_json]}