add e621 api cloner
This commit is contained in:
parent
dab0de4448
commit
fe3a81e614
1 changed files with 141 additions and 0 deletions
141
e621_api_cloner.py
Normal file
141
e621_api_cloner.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import aiosqlite
|
||||
import sqlite3
|
||||
import sys
|
||||
import os
|
||||
import enum
|
||||
from datetime import datetime
|
||||
from dataclasses import dataclass, asdict
|
||||
from hypercorn.asyncio import serve, Config
|
||||
from quart import Quart, request, jsonify
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
app = Quart(__name__)
|
||||
|
||||
|
||||
@app.before_serving
|
||||
async def app_before_serving():
|
||||
db_path = sys.argv[1]
|
||||
app.db = await aiosqlite.connect(db_path)
|
||||
app.db.row_factory = sqlite3.Row
|
||||
|
||||
|
||||
@app.after_serving
|
||||
async def app_after_serving():
|
||||
await app.db.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tag:
|
||||
id: int
|
||||
name: str
|
||||
category: int
|
||||
post_count: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class Post:
|
||||
id: int
|
||||
uploader_id: int
|
||||
created_at: int
|
||||
md5: str
|
||||
source: str
|
||||
rating: str
|
||||
tag_string: str
|
||||
is_deleted: int
|
||||
is_pending: int
|
||||
is_flagged: int
|
||||
score: int
|
||||
up_score: int
|
||||
down_score: int
|
||||
is_rating_locked: int
|
||||
|
||||
def to_json(self):
|
||||
post = asdict(self)
|
||||
post["created_at"] = datetime.fromtimestamp(self.created_at).isoformat()
|
||||
post["score"] = {
|
||||
"up": self.up_score,
|
||||
"down": self.down_score,
|
||||
"total": self.score,
|
||||
}
|
||||
post.pop("up_score")
|
||||
post.pop("down_score")
|
||||
post["flags"] = {
|
||||
"pending": post.pop("is_pending"),
|
||||
"flagged": post.pop("is_flagged"),
|
||||
"deleted": post.pop("is_deleted"),
|
||||
"rating_locked": post.pop("is_rating_locked"),
|
||||
}
|
||||
return post
|
||||
|
||||
|
||||
class TagCategory(enum.IntEnum):
|
||||
GENERAL = 0
|
||||
ARTIST = 1
|
||||
COPYRIGHT = 3
|
||||
CHARACTER = 4
|
||||
SPECIES = 5
|
||||
DEPRECATED = 6
|
||||
METADATA = 7
|
||||
LORE = 8
|
||||
|
||||
def to_string(self) -> str:
|
||||
return {
|
||||
self.GENERAL: "general",
|
||||
self.ARTIST: "artist",
|
||||
self.COPYRIGHT: "copyright",
|
||||
self.CHARACTER: "character",
|
||||
self.METADATA: "metadata",
|
||||
self.DEPRECATED: "deprecated",
|
||||
self.SPECIES: "species",
|
||||
self.LORE: "lore",
|
||||
}[self]
|
||||
|
||||
|
||||
@app.route("/posts.json")
|
||||
async def posts_json():
|
||||
tag_str = request.args["tags"]
|
||||
tags = tag_str.split(" ")
|
||||
if len(tags) != 1:
|
||||
return "unsupported query (only 1 tag)", 400
|
||||
|
||||
md5_tag = tags[0]
|
||||
if not md5_tag.startswith("md5:"):
|
||||
return "unsupported query (only md5 for now)", 400
|
||||
|
||||
md5_prefix, md5_hexencoded_hash = md5_tag.split(":")
|
||||
assert md5_prefix == "md5"
|
||||
|
||||
rows = await app.db.execute_fetchall(
|
||||
"select * from posts where md5 = ?", (md5_hexencoded_hash,)
|
||||
)
|
||||
if not rows:
|
||||
return {"posts": []}
|
||||
|
||||
post = Post(**rows[0])
|
||||
post_json = post.to_json()
|
||||
post_json["tags"] = {}
|
||||
for tag in post.tag_string.split(" "):
|
||||
tag_rows = await app.db.execute_fetchall(
|
||||
"select category from tags where name = ?", (tag,)
|
||||
)
|
||||
category = TagCategory(tag_rows[0][0])
|
||||
category_str = category.to_string()
|
||||
if category_str not in post_json["tags"]:
|
||||
post_json["tags"][category_str] = []
|
||||
post_json["tags"][category_str].append(tag)
|
||||
|
||||
return {"posts": [post_json]}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(
|
||||
level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO
|
||||
)
|
||||
config = Config()
|
||||
config.bind = ["0.0.0.0:1334"]
|
||||
asyncio.run(serve(app, config))
|
Loading…
Reference in a new issue