diff --git a/e621_api_cloner.py b/e621_api_cloner.py index a19edf6..0daa756 100644 --- a/e621_api_cloner.py +++ b/e621_api_cloner.py @@ -150,16 +150,33 @@ async def posts_json(): if len(tags) != 1: return "unsupported query (only 1 tag)", 400 - md5_tag = tags[0] - if not md5_tag.startswith("md5:"): - return "unsupported query (only md5 for now)", 400 + first_tag = tags[0] + tag_parts = first_tag.split(":") + if len(tag_parts) != 2: + return "unsupported query (tag must be a:b format)", 400 - md5_prefix, md5_hexencoded_hash = md5_tag.split(":") - assert md5_prefix == "md5" + type, value = tag_parts + if type not in ("md5", "id"): + return ( + "unsupported query (tag must be a:b format, and a must be id or md5)", + 400, + ) - rows = await app.db.execute_fetchall( - "select * from posts where md5 = ?", (md5_hexencoded_hash,) - ) + if type == "md5": + md5_hexencoded_hash = value + rows = await app.db.execute_fetchall( + "select * from posts where md5 = ?", (md5_hexencoded_hash,) + ) + elif type == "id": + try: + value = int(value) + except ValueError: + return "id:x x must be int", 400 + rows = await app.db.execute_fetchall( + "select * from posts where id = ?", (value,) + ) + else: + raise AssertionError("must be id or md5") if not rows: return {"posts": []}