def plot3(output_path, normalized_scores, model_scores):
    """Write a grouped bar chart of per-model tagging error counts.

    Models appear on the x axis in descending order of their value in
    ``normalized_scores``. For each model, every entry of
    ``model_scores[model]`` contributes:

    * the length of its ``"incorrect_tags"`` collection to the
      "incorrect tags" bar, and
    * the number of rating names (general / sensitive / questionable /
      explicit) present in that collection to the "incorrect ratings" bar.

    The resulting figure is written to ``output_path`` at 1024x800.
    """
    rating_names = ("general", "sensitive", "questionable", "explicit")

    # Highest normalized score first, so the chart reads best-to-worst.
    ranked_models = sorted(
        normalized_scores,
        key=normalized_scores.get,
        reverse=True,
    )

    columns = {"model": [], "errors": [], "rating_errors": []}
    for name in ranked_models:
        wrong_tag_sets = [
            entry["incorrect_tags"] for entry in model_scores[name].values()
        ]
        columns["errors"].append(sum(len(wrong) for wrong in wrong_tag_sets))
        # A rating counts once per entry whose incorrect tags contain it.
        columns["rating_errors"].append(
            sum(
                rating in wrong
                for wrong in wrong_tag_sets
                for rating in rating_names
            )
        )
        columns["model"].append(name)

    frame = pd.DataFrame(columns)

    tag_bar = go.Bar(
        name="incorrect tags",
        x=frame["model"],
        y=frame["errors"],
    )
    rating_bar = go.Bar(
        name="incorrect ratings",
        x=frame["model"],
        y=frame["rating_errors"],
    )
    figure = go.Figure(data=[tag_bar, rating_bar])
    pio.write_image(figure, output_path, width=1024, height=800)