add some basic analysis about final scores
This commit is contained in:
parent f3ec3ab6d3
commit 099c7f8fc9
1 changed file with 55 additions and 16 deletions
main.py | 71
@@ -289,12 +289,19 @@ async def fight(ctx):
     )
 
 
-def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
+def score(
+    danbooru_tags: Set[str], interrogator_tags: Set[str]
+) -> Tuple[decimal.Decimal, Set[str]]:
     tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
-    return decimal.Decimal(
-        len(tags_in_danbooru) - len(tags_not_in_danbooru)
-    ) / decimal.Decimal(len(danbooru_tags))
+    return (
+        round(
+            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
+            / decimal.Decimal(len(danbooru_tags)),
+            10,
+        ),
+        tags_not_in_danbooru,
+    )
 
 
 async def scores(ctx):
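The metric here is (tags the interrogator got right - tags it invented) / (total Danbooru tags), rounded to 10 decimal places, so a model that hallucinates more tags than it matches scores negative; the new signature also hands back the set of wrong tags. A minimal standalone sketch of the new behavior (the tag sets below are invented for illustration):

    import decimal
    from typing import Set, Tuple

    def score(
        danbooru_tags: Set[str], interrogator_tags: Set[str]
    ) -> Tuple[decimal.Decimal, Set[str]]:
        tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
        tags_not_in_danbooru = interrogator_tags - danbooru_tags
        return (
            round(
                decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
                / decimal.Decimal(len(danbooru_tags)),
                10,
            ),
            tags_not_in_danbooru,
        )

    # invented tag sets, for illustration only
    ground_truth = {"1girl", "solo", "long_hair", "smile"}
    predicted = {"1girl", "solo", "short_hair"}
    # 2 correct, 1 wrong, 4 ground-truth tags: (2 - 1) / 4
    print(score(ground_truth, predicted))
    # (Decimal('0.2500000000'), {'short_hair'})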
@@ -303,28 +310,36 @@ async def scores(ctx):
     all_rows = await ctx.db.execute_fetchall("select md5 from posts")
     all_hashes = set(r[0] for r in all_rows)
 
-    absolute_scores = defaultdict(decimal.Decimal)
+    # absolute_scores = defaultdict(decimal.Decimal)
+    model_scores = defaultdict(dict)
+    incorrect_tags_counters = defaultdict(Counter)
 
     for md5_hash in all_hashes:
         log.info("processing for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
-            rows = await ctx.db.execute_fetchall(
-                "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
-                (md5_hash, interrogator.model_id),
-            )
-            assert len(rows) == 1  # if this fails, run 'fight' mode; there are missing posts
-            tag_string = rows[0][0]
-            interrogator_tags = set(tag_string.split())
-            absolute_scores[interrogator.model_id] += score(
-                danbooru_tags, interrogator_tags
-            )
+            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
+            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
+            for tag in incorrect_tags:
+                incorrect_tags_counters[interrogator.model_id][tag] += 1
+
+            model_scores[interrogator.model_id][md5_hash] = {
+                "score": tagging_score,
+                "incorrect_tags": incorrect_tags,
+            }
+
+    summed_scores = {
+        model_id: sum(d["score"] for d in post_scores.values())
+        for model_id, post_scores in model_scores.items()
+    }
 
     normalized_scores = {
-        model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
+        model: summed_scores[model] / len(all_hashes) for model in summed_scores
     }
 
     print("scores are [worst...best]")
 
     for model in sorted(
         normalized_scores.keys(),
         key=lambda model: normalized_scores[model],
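Keeping per-post scores in model_scores (instead of the old running total) is what makes the later worst/best analysis possible; summed_scores then collapses them back to one total per model, and normalized_scores divides by the post count, i.e. the mean per-post score. The same arithmetic in isolation (model name and numbers invented):

    import decimal

    # invented data in the same shape as model_scores
    model_scores = {
        "some-model": {
            "hash-a": {"score": decimal.Decimal("0.25"), "incorrect_tags": set()},
            "hash-b": {"score": decimal.Decimal("0.75"), "incorrect_tags": set()},
        },
    }
    summed_scores = {
        model_id: sum(d["score"] for d in post_scores.values())
        for model_id, post_scores in model_scores.items()
    }
    # 2 posts total, so this is the mean per-post score
    normalized_scores = {m: summed_scores[m] / 2 for m in summed_scores}
    print(normalized_scores)  # {'some-model': Decimal('0.5')}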
@@ -332,6 +347,30 @@ async def scores(ctx):
     ):
         print(model, normalized_scores[model])
 
+        print("[", end="")
+
+        for bad_md5_hash in sorted(
+            model_scores[model].keys(),
+            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
+        )[:4]:
+            data = model_scores[model][bad_md5_hash]
+            if os.getenv("DEBUG", "0") == "1":
+                print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
+            else:
+                print(data["score"], end=",")
+        print("...", end="")
+
+        for good_md5_hash in sorted(
+            model_scores[model].keys(),
+            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
+            reverse=True,
+        )[:4]:
+            data = model_scores[model][good_md5_hash]
+            print(data["score"], end=",")
+
+        print("]")
+        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
 
 
 async def realmain(ctx):
     await ctx.db.executescript(
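Each model's report line is its normalized score followed by its four lowest-scoring and four highest-scoring posts, found by sorting the per-post scores ascending and descending and slicing. The same idiom on its own, with invented scores:

    scores_by_post = {"a": 0.1, "b": 0.9, "c": -0.5, "d": 0.4, "e": 0.7}
    ordered = sorted(scores_by_post, key=lambda h: scores_by_post[h])
    print(ordered[:4])         # four worst: ['c', 'a', 'd', 'e']
    print(ordered[::-1][:4])   # four best:  ['b', 'e', 'd', 'a']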