diff --git a/build_database.py b/build_database.py index c91621e..f40d6cf 100644 --- a/build_database.py +++ b/build_database.py @@ -98,8 +98,8 @@ async def main_with_ctx(ctx, wanted_date): # write to output log.info("copying temp to output") - with output_path.open(mode="wb") as output_fd: - shutil.copyfileobj(temp_fd, output_fd) + with output_path.open(mode="wb") as output_fdout: + shutil.copyfileobj(temp_fd, output_fdout) # decompress for url_type, _url in urls.items(): diff --git a/e621_api_cloner.py b/e621_api_cloner.py index a19edf6..0daa756 100644 --- a/e621_api_cloner.py +++ b/e621_api_cloner.py @@ -150,16 +150,33 @@ async def posts_json(): if len(tags) != 1: return "unsupported query (only 1 tag)", 400 - md5_tag = tags[0] - if not md5_tag.startswith("md5:"): - return "unsupported query (only md5 for now)", 400 + first_tag = tags[0] + tag_parts = first_tag.split(":") + if len(tag_parts) != 2: + return "unsupported query (tag must be a:b format)", 400 - md5_prefix, md5_hexencoded_hash = md5_tag.split(":") - assert md5_prefix == "md5" + type, value = tag_parts + if type not in ("md5", "id"): + return ( + "unsupported query (tag must be a:b format, and a must be id or md5)", + 400, + ) - rows = await app.db.execute_fetchall( - "select * from posts where md5 = ?", (md5_hexencoded_hash,) - ) + if type == "md5": + md5_hexencoded_hash = value + rows = await app.db.execute_fetchall( + "select * from posts where md5 = ?", (md5_hexencoded_hash,) + ) + elif type == "id": + try: + value = int(value) + except ValueError: + return "id:x x must be int", 400 + rows = await app.db.execute_fetchall( + "select * from posts where id = ?", (value,) + ) + else: + raise AssertionError("must be id or md5") if not rows: return {"posts": []}