diff --git a/main.py b/main.py index 5ca42aa..40a6b06 100644 --- a/main.py +++ b/main.py @@ -194,7 +194,21 @@ async def fetch_post(ctx, md5) -> Optional[dict]: if not rows: return None assert len(rows) == 1 - return json.loads(rows[0][0]) + post = json.loads(rows[0][0]) + post_rating = post["rating"] + match post_rating: + case "g": + rating_tag = "general" + case "s": + rating_tag = "sensitive" + case "q": + rating_tag = "questionable" + case "e": + rating_tag = "explicit" + case _: + raise AssertionError("invalid post rating {post_rating!r}") + post["tag_string"] = post["tag_string"] + " " + rating_tag + return post async def insert_post(ctx, post): @@ -250,8 +264,8 @@ async def fight(ctx): def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal: - tags_in_danbooru = danbooru_tags | interrogator_tags - tags_not_in_danbooru = danbooru_tags - interrogator_tags + tags_in_danbooru = danbooru_tags.intersection(interrogator_tags) + tags_not_in_danbooru = interrogator_tags - danbooru_tags return decimal.Decimal( len(tags_in_danbooru) - len(tags_not_in_danbooru) ) / decimal.Decimal(len(danbooru_tags))