diff --git a/e621_api_cloner.py b/e621_api_cloner.py index 5c5749b..8273ca5 100644 --- a/e621_api_cloner.py +++ b/e621_api_cloner.py @@ -7,6 +7,7 @@ import sqlite3 import sys import os import enum +from pathlib import Path from datetime import datetime from dataclasses import dataclass, asdict from hypercorn.asyncio import serve, Config @@ -19,10 +20,25 @@ app = Quart(__name__) @app.before_serving async def app_before_serving(): + logging.basicConfig( + level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO + ) + db_path = os.environ["DBPATH"] app.db = await aiosqlite.connect(db_path) app.db.row_factory = sqlite3.Row + # contains api keys + maybe_authfile_path = os.environ.get("AUTHFILE") + app.maybe_authfile = Path(maybe_authfile_path) if maybe_authfile_path else None + if app.maybe_authfile: + app.apikeys = {} + log.info("loading auth with api keys") + with app.maybe_authfile.open(mode="r") as fd: + for line in fd: + api_key, *user_name = line.split(" ") + app.apikeys[api_key] = (" ".join(user_name)).strip() + @app.after_serving async def app_after_serving(): @@ -96,8 +112,34 @@ class TagCategory(enum.IntEnum): }[self] +async def maybe_do_authentication(): + if not app.maybe_authfile: + return None + + auth_line = request.headers.get("authorization") + if not auth_line: + return "authorization header required", 400 + + auth_type, auth_data = auth_line.split(" ") + if auth_type != "Bearer": + log.warn("invalid auth type") + return "invalid auth token type (must be Bearer)", 400 + + auth_name = app.apikeys.get(auth_data) + if auth_name is None: + log.warn("invalid auth value") + return "invalid auth token (unknown api key)", 400 + + log.info("logged in as %r", auth_name) + return None + + @app.route("/posts.json") async def posts_json(): + res = await maybe_do_authentication() + if res: + return res + tag_str = request.args["tags"] tags = tag_str.split(" ") if len(tags) != 1: @@ -130,12 +172,3 @@ async def posts_json(): post_json["tags"][category_str].append(tag) return {"posts": [post_json]} - - -if __name__ == "__main__": - logging.basicConfig( - level=logging.DEBUG if os.environ.get("DEBUG") else logging.INFO - ) - config = Config() - config.bind = ["0.0.0.0:1334"] - asyncio.run(serve(app, config))