use f1 scores

This commit is contained in:
Luna 2025-04-12 00:48:01 -03:00
parent a73a3f26ce
commit c721e5ee31

109
main.py
View file

@ -398,14 +398,42 @@ async def fight(ctx):
def score(
danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
f1 = None
# Handle edge cases
if len(danbooru_tags) == 0 and len(interrogator_tags) == 0:
f1 = decimal.Decimal("1.0") # Both empty means perfect match
if len(danbooru_tags) == 0 or len(interrogator_tags) == 0:
f1 = decimal.Decimal("0.0") # One empty means no match
# Calculate true positives (tags that appear in both sets)
true_positives = decimal.Decimal(len(danbooru_tags.intersection(interrogator_tags)))
# Calculate precision: TP / (TP + FP)
precision = (
true_positives / len(interrogator_tags)
if len(interrogator_tags) > 0
else decimal.Decimal("0.0")
)
# Calculate recall: TP / (TP + FN)
recall = (
true_positives / len(danbooru_tags)
if len(danbooru_tags) > 0
else decimal.Decimal("0.0")
)
print("recall", recall)
# Handle the case where both precision and recall are 0
if f1 is None and precision == 0 and recall == 0:
f1 = decimal.Decimal("0.0")
else:
f1 = decimal.Decimal("2.0") * (precision * recall) / (precision + recall)
tags_not_in_danbooru = interrogator_tags - danbooru_tags
return (
round(
decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
/ decimal.Decimal(len(danbooru_tags)),
10,
),
round(f1, 10),
tags_not_in_danbooru,
)
@ -420,9 +448,10 @@ async def scores(ctx):
model_scores = defaultdict(dict)
runtimes = defaultdict(list)
incorrect_tags_counters = defaultdict(Counter)
predicted_tags_counter = defaultdict(int)
for md5_hash in all_hashes:
log.info("processing for %r", md5_hash)
log.info("processing score for %r", md5_hash)
post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators:
@ -433,9 +462,14 @@ async def scores(ctx):
for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1
log.info(f"{interrogator.model_id} {tagging_score}")
predicted_tags_counter[interrogator.model_id] += len(interrogator_tags)
correct_tags = interrogator_tags.intersection(danbooru_tags)
model_scores[interrogator.model_id][md5_hash] = {
"score": tagging_score,
"predicted_tags": interrogator_tags,
"incorrect_tags": incorrect_tags,
"correct_tags": correct_tags,
}
summed_scores = {
@ -480,8 +514,19 @@ async def scores(ctx):
print(data["score"], end=",")
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
total_incorrect = 0
for _, c in incorrect_tags_counters[model].most_common(10000000):
total_incorrect += c
print(
"most incorrect tags from",
total_incorrect,
"incorrect tags",
"predicted",
predicted_tags_counter[model],
"tags",
)
for t, c in incorrect_tags_counters[model].most_common(7):
print("\t", t, c)
PLOTS = Path.cwd() / "plots"
PLOTS.mkdir(exist_ok=True)
@ -514,7 +559,12 @@ async def scores(ctx):
log.info("plotting positive histogram...")
plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
log.info("plotting error rates...")
plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
plot3(
PLOTS / "error_rate.png",
PLOTS / "score_avg.png",
normalized_scores,
model_scores,
)
def plot2(output_path, normalized_scores, model_scores):
@ -545,12 +595,13 @@ def plot2(output_path, normalized_scores, model_scores):
pio.write_image(fig, output_path, width=1024, height=800)
def plot3(output_path, normalized_scores, model_scores):
def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
data_for_df = {
"model": [],
"errors": [],
"rating_errors": [],
"practical_errors": [],
"score_avg": [],
"predicted": [],
"correct": [],
"incorrect": [],
}
for model in sorted(
@ -558,32 +609,34 @@ def plot3(output_path, normalized_scores, model_scores):
key=lambda model: normalized_scores[model],
reverse=True,
):
total_incorrect_tags = 0
total_rating_errors = 0
total_predicted_tags, total_incorrect_tags, total_correct_tags = 0, 0, 0
for score_data in model_scores[model].values():
total_predicted_tags += len(score_data["predicted_tags"])
total_incorrect_tags += len(score_data["incorrect_tags"])
total_rating_errors += sum(
1
for rating in ["general", "sensitive", "questionable", "explicit"]
if rating in score_data["incorrect_tags"]
)
practical_absolute_error = total_incorrect_tags - total_rating_errors
total_correct_tags += len(score_data["correct_tags"])
data_for_df["errors"].append(total_incorrect_tags)
data_for_df["rating_errors"].append(total_rating_errors)
data_for_df["practical_errors"].append(practical_absolute_error)
data_for_df["score_avg"].append(normalized_scores[model])
data_for_df["predicted"].append(total_predicted_tags)
data_for_df["incorrect"].append(total_incorrect_tags)
data_for_df["correct"].append(total_correct_tags)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = go.Figure(
data=[
go.Bar(name="incorrect tags", x=df.model, y=df.errors),
go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
go.Bar(name="practical error", x=df.model, y=df.practical_errors),
go.Bar(name="predicted tags", x=df.model, y=df.predicted),
go.Bar(name="incorrect tags", x=df.model, y=df.incorrect),
go.Bar(name="correct tags", x=df.model, y=df.correct),
]
)
pio.write_image(fig, output_path, width=1024, height=800)
fig2 = go.Figure(
data=[
go.Bar(name="score avg", x=df.model, y=df.score_avg),
]
)
pio.write_image(fig2, output_score_avg_path, width=1024, height=800)
async def realmain(ctx):