From 099c7f8fc9b31cc3c2de4a832ab745c46a74ca2d Mon Sep 17 00:00:00 2001
From: Luna
Date: Sat, 10 Jun 2023 01:30:48 -0300
Subject: [PATCH] add some basic analysis about final scores

---
 main.py | 71 ++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 55 insertions(+), 16 deletions(-)

diff --git a/main.py b/main.py
index b08e821..ed9ca39 100644
--- a/main.py
+++ b/main.py
@@ -289,12 +289,19 @@ async def fight(ctx):
     )
 
 
-def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
+def score(
+    danbooru_tags: Set[str], interrogator_tags: Set[str]
+) -> Tuple[decimal.Decimal, Set[str]]:
     tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
-    return decimal.Decimal(
-        len(tags_in_danbooru) - len(tags_not_in_danbooru)
-    ) / decimal.Decimal(len(danbooru_tags))
+    return (
+        round(
+            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
+            / decimal.Decimal(len(danbooru_tags)),
+            10,
+        ),
+        tags_not_in_danbooru,
+    )
 
 
 async def scores(ctx):
@@ -303,28 +310,36 @@ async def scores(ctx):
     all_rows = await ctx.db.execute_fetchall("select md5 from posts")
     all_hashes = set(r[0] for r in all_rows)
 
-    absolute_scores = defaultdict(decimal.Decimal)
+    # absolute_scores = defaultdict(decimal.Decimal)
+    model_scores = defaultdict(dict)
+    incorrect_tags_counters = defaultdict(Counter)
 
     for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
-            rows = await ctx.db.execute_fetchall(
-                "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
-                (md5_hash, interrogator.model_id),
-            )
-            assert len(rows) == 1  # run 'fight' mode, there's missing posts
-            tag_string = rows[0][0]
-            interrogator_tags = set(tag_string.split())
-            absolute_scores[interrogator.model_id] += score(
-                danbooru_tags, interrogator_tags
-            )
+            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
+            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
+            for tag in incorrect_tags:
+                incorrect_tags_counters[interrogator.model_id][tag] += 1
+
+            model_scores[interrogator.model_id][md5_hash] = {
+                "score": tagging_score,
+                "incorrect_tags": incorrect_tags,
+            }
+
+    summed_scores = {
+        model_id: sum(d["score"] for d in post_scores.values())
+        for model_id, post_scores in model_scores.items()
+    }
 
     normalized_scores = {
-        model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
+        model: summed_scores[model] / len(all_hashes) for model in summed_scores
     }
 
+    print("scores are [worst...best]")
+
     for model in sorted(
         normalized_scores.keys(),
         key=lambda model: normalized_scores[model],
@@ -332,6 +347,30 @@
     ):
         print(model, normalized_scores[model])
+        print("[", end="")
+
+        for bad_md5_hash in sorted(
+            model_scores[model].keys(),
+            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
+        )[:4]:
+            data = model_scores[model][bad_md5_hash]
+            if os.getenv("DEBUG", "0") == "1":
+                print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
+            else:
+                print(data["score"], end=",")
+
+        print("...", end="")
+
+        for good_md5_hash in sorted(
+            model_scores[model].keys(),
+            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
+            reverse=True,
+        )[:4]:
+            data = model_scores[model][good_md5_hash]
+            print(data["score"], end=",")
+
+        print("]")
+        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
 
 
 async def realmain(ctx):
     await ctx.db.executescript(