add some basic analysis about final scores
This commit is contained in:
parent
f3ec3ab6d3
commit
099c7f8fc9
71
main.py
71
main.py
|
@ -289,12 +289,19 @@ async def fight(ctx):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score an interrogator's tag output against the danbooru ground truth.

    The score rewards tags the interrogator got right (present in the
    danbooru tag set) and penalizes tags it invented (absent from danbooru),
    normalized by the size of the ground-truth set:

        (|correct| - |incorrect|) / |danbooru_tags|

    Args:
        danbooru_tags: ground-truth tag set for the post.
        interrogator_tags: tag set produced by the interrogator model.

    Returns:
        A tuple of (score, incorrect_tags) where score is a Decimal rounded
        to 10 decimal places and incorrect_tags is the set of tags the
        interrogator emitted that are not in the danbooru set.

    NOTE(review): an empty danbooru_tags set raises decimal.DivisionByZero
    here — presumably callers never pass one; confirm upstream.
    """
    tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    return (
        # Decimal (not float) division keeps the score exact before rounding.
        round(
            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
            / decimal.Decimal(len(danbooru_tags)),
            10,
        ),
        tags_not_in_danbooru,
    )
|
||||||
|
|
||||||
|
|
||||||
async def scores(ctx):
|
async def scores(ctx):
|
||||||
|
@ -303,28 +310,36 @@ async def scores(ctx):
|
||||||
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
||||||
all_hashes = set(r[0] for r in all_rows)
|
all_hashes = set(r[0] for r in all_rows)
|
||||||
|
|
||||||
absolute_scores = defaultdict(decimal.Decimal)
|
# absolute_scores = defaultdict(decimal.Decimal)
|
||||||
|
model_scores = defaultdict(dict)
|
||||||
|
incorrect_tags_counters = defaultdict(Counter)
|
||||||
|
|
||||||
for md5_hash in all_hashes:
|
for md5_hash in all_hashes:
|
||||||
log.info("processing for %r", md5_hash)
|
log.info("processing for %r", md5_hash)
|
||||||
post = await fetch_post(ctx, md5_hash)
|
post = await fetch_post(ctx, md5_hash)
|
||||||
danbooru_tags = set(post["tag_string"].split())
|
danbooru_tags = set(post["tag_string"].split())
|
||||||
for interrogator in interrogators:
|
for interrogator in interrogators:
|
||||||
rows = await ctx.db.execute_fetchall(
|
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
|
||||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
|
||||||
(md5_hash, interrogator.model_id),
|
for tag in incorrect_tags:
|
||||||
)
|
incorrect_tags_counters[interrogator.model_id][tag] += 1
|
||||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
|
||||||
tag_string = rows[0][0]
|
model_scores[interrogator.model_id][md5_hash] = {
|
||||||
interrogator_tags = set(tag_string.split())
|
"score": tagging_score,
|
||||||
absolute_scores[interrogator.model_id] += score(
|
"incorrect_tags": incorrect_tags,
|
||||||
danbooru_tags, interrogator_tags
|
}
|
||||||
)
|
|
||||||
|
summed_scores = {
|
||||||
|
model_id: sum(d["score"] for d in post_scores.values())
|
||||||
|
for model_id, post_scores in model_scores.items()
|
||||||
|
}
|
||||||
|
|
||||||
normalized_scores = {
|
normalized_scores = {
|
||||||
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
|
model: summed_scores[model] / len(all_hashes) for model in summed_scores
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print("scores are [worst...best]")
|
||||||
|
|
||||||
for model in sorted(
|
for model in sorted(
|
||||||
normalized_scores.keys(),
|
normalized_scores.keys(),
|
||||||
key=lambda model: normalized_scores[model],
|
key=lambda model: normalized_scores[model],
|
||||||
|
@ -332,6 +347,30 @@ async def scores(ctx):
|
||||||
):
|
):
|
||||||
print(model, normalized_scores[model])
|
print(model, normalized_scores[model])
|
||||||
|
|
||||||
|
print("[", end="")
|
||||||
|
|
||||||
|
for bad_md5_hash in sorted(
|
||||||
|
model_scores[model].keys(),
|
||||||
|
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||||
|
)[:4]:
|
||||||
|
data = model_scores[model][bad_md5_hash]
|
||||||
|
if os.getenv("DEBUG", "0") == "1":
|
||||||
|
print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
|
||||||
|
else:
|
||||||
|
print(data["score"], end=",")
|
||||||
|
print("...", end="")
|
||||||
|
|
||||||
|
for good_md5_hash in sorted(
|
||||||
|
model_scores[model].keys(),
|
||||||
|
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||||
|
reverse=True,
|
||||||
|
)[:4]:
|
||||||
|
data = model_scores[model][good_md5_hash]
|
||||||
|
print(data["score"], end=",")
|
||||||
|
|
||||||
|
print("]")
|
||||||
|
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
|
||||||
|
|
||||||
|
|
||||||
async def realmain(ctx):
|
async def realmain(ctx):
|
||||||
await ctx.db.executescript(
|
await ctx.db.executescript(
|
||||||
|
|
Loading…
Reference in New Issue