diff --git a/main.py b/main.py index 40a6b06..ed9ca39 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import os import json import aiohttp import pprint @@ -8,10 +9,10 @@ import asyncio import aiosqlite import base64 import logging -from collections import defaultdict +from collections import defaultdict, Counter from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional, List, Set +from typing import Any, Optional, List, Set, Tuple from aiolimiter import AsyncLimiter from dataclasses import dataclass, field @@ -33,8 +34,32 @@ class Interrogator: model_id: str address: str + def _process(self, lst): + return lst + + async def fetch_tags(self, ctx, md5_hash): + rows = await ctx.db.execute_fetchall( + "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?", + (md5_hash, self.model_id), + ) + assert len(rows) == 1 # run 'fight' mode, there's missing posts + tag_string = rows[0][0] + return self._process(tag_string.split()) + class DDInterrogator(Interrogator): + def _process(self, lst): + new_lst = [] + for tag in lst: + if tag.startswith("rating:"): + original_danbooru_tag = tag.split(":")[1] + else: + original_danbooru_tag = tag + if original_danbooru_tag == "safe": + continue + new_lst.append(original_danbooru_tag) + return new_lst + async def interrogate(self, session, path): async with session.post( f"{self.address}/", @@ -163,6 +188,7 @@ async def download_images(ctx): log.info("processing post %r", md5) existing_post = await fetch_post(ctx, md5) if existing_post: + log.info("already exists %r", md5) continue # download the post @@ -170,10 +196,10 @@ async def download_images(ctx): post_url_path = Path(urlparse(post_file_url).path) good_extension = False for extension in ALLOWED_EXTENSIONS: - if extension in post_url_path.stem: + if extension in post_url_path.suffix: good_extension = True if not good_extension: - log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem) + log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix) continue post_filename = post_url_path.name post_filepath = DOWNLOADS / post_filename @@ -263,12 +289,19 @@ async def fight(ctx): ) -def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal: +def score( + danbooru_tags: Set[str], interrogator_tags: Set[str] +) -> Tuple[decimal.Decimal, Set[str]]: tags_in_danbooru = danbooru_tags.intersection(interrogator_tags) tags_not_in_danbooru = interrogator_tags - danbooru_tags - return decimal.Decimal( - len(tags_in_danbooru) - len(tags_not_in_danbooru) - ) / decimal.Decimal(len(danbooru_tags)) + return ( + round( + decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru)) + / decimal.Decimal(len(danbooru_tags)), + 10, + ), + tags_not_in_danbooru, + ) async def scores(ctx): @@ -277,28 +310,36 @@ async def scores(ctx): all_rows = await ctx.db.execute_fetchall("select md5 from posts") all_hashes = set(r[0] for r in all_rows) - absolute_scores = defaultdict(decimal.Decimal) + # absolute_scores = defaultdict(decimal.Decimal) + model_scores = defaultdict(dict) + incorrect_tags_counters = defaultdict(Counter) for md5_hash in all_hashes: log.info("processing for %r", md5_hash) post = await fetch_post(ctx, md5_hash) danbooru_tags = set(post["tag_string"].split()) for interrogator in interrogators: - rows = await ctx.db.execute_fetchall( - "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?", - (md5_hash, interrogator.model_id), - ) - assert len(rows) == 1 # run 'fight' mode, there's missing posts - tag_string = rows[0][0] - interrogator_tags = set(tag_string.split()) - absolute_scores[interrogator.model_id] += score( - danbooru_tags, interrogator_tags - ) + interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash)) + tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags) + for tag in incorrect_tags: + incorrect_tags_counters[interrogator.model_id][tag] += 1 + + model_scores[interrogator.model_id][md5_hash] = { + "score": tagging_score, + "incorrect_tags": incorrect_tags, + } + + summed_scores = { + model_id: sum(d["score"] for d in post_scores.values()) + for model_id, post_scores in model_scores.items() + } normalized_scores = { - model: absolute_scores[model] / len(all_hashes) for model in absolute_scores + model: summed_scores[model] / len(all_hashes) for model in summed_scores } + print("scores are [worst...best]") + for model in sorted( normalized_scores.keys(), key=lambda model: normalized_scores[model], @@ -306,6 +347,30 @@ async def scores(ctx): ): print(model, normalized_scores[model]) + print("[", end="") + + for bad_md5_hash in sorted( + model_scores[model].keys(), + key=lambda md5_hash: model_scores[model][md5_hash]["score"], + )[:4]: + data = model_scores[model][bad_md5_hash] + if os.getenv("DEBUG", "0") == "1": + print(md5_hash, data["score"], " ".join(data["incorrect_tags"])) + else: + print(data["score"], end=",") + print("...", end="") + + for good_md5_hash in sorted( + model_scores[model].keys(), + key=lambda md5_hash: model_scores[model][md5_hash]["score"], + reverse=True, + )[:4]: + data = model_scores[model][good_md5_hash] + print(data["score"], end=",") + + print("]") + print("most incorrect tags", incorrect_tags_counters[model].most_common(5)) + async def realmain(ctx): await ctx.db.executescript(