178 lines
4.5 KiB
Python
178 lines
4.5 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import asyncio
|
|
import logging
|
|
import aiosqlite
|
|
import sqlite3
|
|
import sys
|
|
import os
|
|
import enum
|
|
from pathlib import Path
|
|
from datetime import datetime
|
|
from dataclasses import dataclass, asdict
|
|
from hypercorn.asyncio import serve, Config
|
|
from quart import Quart, request, jsonify
|
|
|
|
|
|
# Module-level logger; the log level itself is configured in app_before_serving().
log = logging.getLogger(__name__)
|
|
# The Quart application. The serving hooks below attach `db`, `maybe_authfile`
# and (when an auth file is configured) `apikeys` to this object.
app = Quart(__name__)
|
|
|
|
|
|
@app.before_serving
async def app_before_serving():
    """Prepare process-wide state before the first request is served.

    Configures logging (DEBUG when the ``DEBUG`` environment variable is set
    to a non-empty value, INFO otherwise), opens the SQLite database named by
    the required ``DBPATH`` environment variable and, when ``AUTHFILE`` is
    set, loads API keys from that file into ``app.apikeys``.

    Each non-blank line of the auth file has the form
    ``<api key> <user name>``; the user name may contain spaces and may be
    omitted entirely.

    Raises:
        KeyError: if ``DBPATH`` is not set.
    """
    logging.basicConfig(
        level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
    )

    db_path = os.environ["DBPATH"]
    app.db = await aiosqlite.connect(db_path)
    # Rows behave like mappings, which lets callers do Post(**row).
    app.db.row_factory = sqlite3.Row

    # contains api keys
    maybe_authfile_path = os.environ.get("AUTHFILE")
    app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None
    if app.maybe_authfile:
        app.apikeys = {}
        log.info("loading auth with api keys")
        with app.maybe_authfile.open(mode="r") as fd:
            for line in fd:
                # Strip the line terminator *before* splitting: otherwise a
                # key-only line keeps its trailing "\n" inside the key and
                # can never match a request token.
                api_key, _, user_name = line.strip().partition(" ")
                if not api_key:
                    # Blank line: nothing to register.
                    continue
                app.apikeys[api_key] = user_name.strip()
|
|
|
|
|
|
@app.after_serving
async def app_after_serving():
    """Give SQLite a chance to optimize, then close the shared connection.

    The connection is closed even when one of the pragmas raises, so a
    failing optimization step cannot leak the connection at shutdown. The
    original exception still propagates after the close.
    """
    try:
        log.info("possibly optimizing database")
        # Bound the amount of analysis "PRAGMA optimize" is allowed to do.
        await app.db.execute("PRAGMA analysis_limit=400")
        await app.db.execute("PRAGMA optimize")
    finally:
        log.info("closing connection")
        await app.db.close()
|
|
|
|
|
|
@dataclass
class Tag:
    """Plain record describing one tag.

    The field order below is also the positional order of the generated
    constructor, so it must not be rearranged.
    """

    # NOTE(review): presumably mirrors a row of the `tags` table -- confirm.
    id: int
    name: str
    # NOTE(review): looks like a TagCategory value stored as a plain int -- confirm.
    category: int
    post_count: int
|
|
|
|
|
|
@dataclass
class Post:
    """One post record, constructed from a database row via ``Post(**row)``.

    The ``is_*`` flags are stored as integers and ``created_at`` as a
    timestamp accepted by ``datetime.fromtimestamp``.
    """

    id: int
    uploader_id: int
    created_at: int
    md5: str
    source: str
    rating: str
    tag_string: str
    is_deleted: int
    is_pending: int
    is_flagged: int
    score: int
    up_score: int
    down_score: int
    is_rating_locked: int

    def to_json(self):
        """Return this post as a JSON-ready dict.

        Compared to the flat dataclass fields:

        * ``created_at`` becomes an ISO-8601 string,
        * ``score`` becomes ``{"up", "down", "total"}`` and the separate
          ``up_score`` / ``down_score`` keys are removed,
        * the four ``is_*`` fields are folded into a ``flags`` dict.
        """
        payload = asdict(self)

        # NOTE(review): fromtimestamp() without a tz yields naive local time,
        # so the rendered value depends on the server's timezone.
        payload["created_at"] = datetime.fromtimestamp(self.created_at).isoformat()

        # Replaces the flat integer score in place (keeps its key position).
        payload["score"] = dict(
            up=self.up_score,
            down=self.down_score,
            total=self.score,
        )
        for flat_key in ("up_score", "down_score"):
            del payload[flat_key]

        flag_sources = (
            ("pending", "is_pending"),
            ("flagged", "is_flagged"),
            ("deleted", "is_deleted"),
            ("rating_locked", "is_rating_locked"),
        )
        payload["flags"] = {
            flag_name: payload.pop(field_name)
            for flag_name, field_name in flag_sources
        }
        return payload
|
|
|
|
|
|
class TagCategory(enum.IntEnum):
    """Integer tag categories as stored in the ``category`` column.

    The value 2 is intentionally absent from the numbering below.
    """

    GENERAL = 0
    ARTIST = 1
    COPYRIGHT = 3
    CHARACTER = 4
    SPECIES = 5
    DEPRECATED = 6
    METADATA = 7
    LORE = 8

    def to_string(self) -> str:
        """Return the lowercase label for this category (e.g. ``"artist"``).

        Every label is exactly the member name in lowercase, so it is
        derived from ``self.name`` rather than from a lookup table.
        """
        return self.name.lower()
|
|
|
|
|
|
async def maybe_do_authentication():
    """Validate the request's API key when authentication is enabled.

    Authentication is enabled only when an auth file was configured at
    startup (``app.maybe_authfile`` is set). In that case the request must
    carry ``Authorization: Bearer <api key>`` with a key present in
    ``app.apikeys``.

    Returns:
        ``None`` when the request may proceed, otherwise a
        ``(message, status)`` tuple that the calling route returns as-is.
    """
    if not app.maybe_authfile:
        return None

    auth_line = request.headers.get("authorization")
    if not auth_line:
        return "authorization header required", 400

    # partition() never raises. Unpacking split(" ") into two names blew up
    # with ValueError (an HTTP 500) for headers with no space or with more
    # than one space, e.g. "Bearer" or "Bearer a b".
    auth_type, _, auth_data = auth_line.strip().partition(" ")
    if auth_type != "Bearer":
        log.warning("invalid auth type")
        return "invalid auth token type (must be Bearer)", 400

    auth_data = auth_data.strip()
    # An empty token is never valid, whatever the key table contains.
    auth_name = app.apikeys.get(auth_data) if auth_data else None
    if auth_name is None:
        log.warning("invalid auth value")
        return "invalid auth token (unknown api key)", 400

    log.info("logged in as %r", auth_name)
    return None
|
|
|
|
|
|
@app.route("/posts.json")
async def posts_json():
    """Look up a single post by MD5 and return it as ``{"posts": [...]}``.

    Only queries of the form ``tags=md5:<hex digest>`` are supported; any
    other query is answered with a 400 and a short explanation. When
    authentication is enabled and fails, the authentication error is
    returned instead.

    The post's ``tag_string`` is expanded into ``post["tags"]``, a dict
    mapping category label (see ``TagCategory.to_string``) to the list of
    tag names in that category. Tags that have no row in the ``tags`` table
    are logged and left out instead of failing the whole request.

    Returns:
        ``{"posts": []}`` when no post has that MD5, otherwise a one-element
        list with the post's JSON form, or a ``(message, status)`` tuple on
        an unsupported query or failed authentication.
    """
    res = await maybe_do_authentication()
    if res:
        return res

    tag_str = request.args["tags"]
    tags = tag_str.split(" ")
    if len(tags) != 1:
        return "unsupported query (only 1 tag)", 400

    md5_tag = tags[0]
    if not md5_tag.startswith("md5:"):
        return "unsupported query (only md5 for now)", 400

    # Everything after the first ":" is the hash. partition() cannot raise,
    # whereas unpacking split(":") failed with ValueError on "md5:a:b".
    _, _, md5_hexencoded_hash = md5_tag.partition(":")

    rows = await app.db.execute_fetchall(
        "select * from posts where md5 = ?", (md5_hexencoded_hash,)
    )
    if not rows:
        return {"posts": []}

    post = Post(**rows[0])
    post_json = post.to_json()
    tags_by_category = {}
    for tag in post.tag_string.split(" "):
        if not tag:
            # An empty tag_string (or doubled spaces) yields "" entries.
            continue
        tag_rows = await app.db.execute_fetchall(
            "select category from tags where name = ?", (tag,)
        )
        if not tag_rows:
            log.warning("post %r references unknown tag %r", post.id, tag)
            continue
        category_str = TagCategory(tag_rows[0][0]).to_string()
        tags_by_category.setdefault(category_str, []).append(tag)
    post_json["tags"] = tags_by_category

    return {"posts": [post_json]}
|