add some basic analysis about final scores

This commit is contained in:
Luna 2023-06-10 01:30:48 -03:00
parent f3ec3ab6d3
commit 099c7f8fc9
1 changed files with 55 additions and 16 deletions

71
main.py
View File

@ -289,12 +289,19 @@ async def fight(ctx):
)
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
def score(
danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
tags_not_in_danbooru = interrogator_tags - danbooru_tags
return decimal.Decimal(
len(tags_in_danbooru) - len(tags_not_in_danbooru)
) / decimal.Decimal(len(danbooru_tags))
return (
round(
decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
/ decimal.Decimal(len(danbooru_tags)),
10,
),
tags_not_in_danbooru,
)
async def scores(ctx):
@ -303,28 +310,36 @@ async def scores(ctx):
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
all_hashes = set(r[0] for r in all_rows)
absolute_scores = defaultdict(decimal.Decimal)
# absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict)
incorrect_tags_counters = defaultdict(Counter)
for md5_hash in all_hashes:
log.info("processing for %r", md5_hash)
post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators:
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, interrogator.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
interrogator_tags = set(tag_string.split())
absolute_scores[interrogator.model_id] += score(
danbooru_tags, interrogator_tags
)
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1
model_scores[interrogator.model_id][md5_hash] = {
"score": tagging_score,
"incorrect_tags": incorrect_tags,
}
summed_scores = {
model_id: sum(d["score"] for d in post_scores.values())
for model_id, post_scores in model_scores.items()
}
normalized_scores = {
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
model: summed_scores[model] / len(all_hashes) for model in summed_scores
}
print("scores are [worst...best]")
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
@ -332,6 +347,30 @@ async def scores(ctx):
):
print(model, normalized_scores[model])
print("[", end="")
for bad_md5_hash in sorted(
model_scores[model].keys(),
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
)[:4]:
data = model_scores[model][bad_md5_hash]
if os.getenv("DEBUG", "0") == "1":
print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
else:
print(data["score"], end=",")
print("...", end="")
for good_md5_hash in sorted(
model_scores[model].keys(),
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
reverse=True,
)[:4]:
data = model_scores[model][good_md5_hash]
print(data["score"], end=",")
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
async def realmain(ctx):
await ctx.db.executescript(