diff --git a/.gitignore b/.gitignore
index f7ab5d7..8adcec9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -163,3 +163,4 @@ cython_debug/
 config.json
 data.db
 posts/
+plots/
diff --git a/main.py b/main.py
index 6f36b3d..c5384d0 100644
--- a/main.py
+++ b/main.py
@@ -10,6 +10,10 @@ import asyncio
 import aiosqlite
 import base64
 import logging
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objs as go
+import plotly.io as pio
 from collections import defaultdict, Counter
 from pathlib import Path
 from urllib.parse import urlparse
@@ -404,6 +408,70 @@ async def scores(ctx):
         print("]")
         print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
 
+    PLOTS = Path.cwd() / "plots"
+    PLOTS.mkdir(exist_ok=True)
+
+    data_for_df = {}
+    data_for_df["scores"] = []
+    data_for_df["model"] = []
+
+    for model in sorted(
+        normalized_scores.keys(),
+        key=lambda model: normalized_scores[model],
+        reverse=True,
+    ):
+        for post_score in (d["score"] for d in model_scores[model].values()):
+            data_for_df["scores"].append(post_score)
+            data_for_df["model"].append(model)
+
+    df = pd.DataFrame(data_for_df)
+    fig = px.histogram(
+        df,
+        x="scores",
+        color="model",
+        histfunc="count",
+        marginal="rug",
+        histnorm="probability",
+    )
+    pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)
+
+    plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
+
+
+def plot2(output_path, normalized_scores, model_scores):
+    """Write a per-model histogram of the non-negative post scores.
+
+    output_path is the image file handed to plotly's write_image.
+    normalized_scores maps each model to the value used to order the models
+    (highest first). model_scores maps each model to a mapping whose values
+    are dicts with a "score" key; negative scores are left out of the plot.
+    """
+    data_for_df = {"scores": [], "model": []}
+
+    # Visit models from the highest normalized score to the lowest.
+    for model in sorted(
+        normalized_scores,
+        key=normalized_scores.get,
+        reverse=True,
+    ):
+        for post in model_scores[model].values():
+            post_score = post["score"]
+            if post_score < 0:
+                continue
+            data_for_df["scores"].append(post_score)
+            data_for_df["model"].append(model)
+
+    df = pd.DataFrame(data_for_df)
+    fig = px.histogram(
+        df,
+        x="scores",
+        color="model",
+        histfunc="count",
+        marginal="rug",
+        histnorm="probability",
+    )
+    pio.write_image(fig, output_path, width=1024, height=800)
+
 
 async def realmain(ctx):
     await ctx.db.executescript(
diff --git a/requirements.txt b/requirements.txt
index db0e4df..391859f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,8 @@
 aiosqlite==0.19.0
 aiohttp==3.8.4
-aiolimiter==1.1.0
\ No newline at end of file
+aiolimiter>=1.1.0,<2.0
+
+
+plotly>=5.15.0,<6.0
+pandas==2.0.2
+kaleido==0.2.1
\ No newline at end of file