add histogram plotting of scores

This commit is contained in:
Luna 2023-06-10 15:26:02 -03:00
parent d8a4f6aaaf
commit 389a582e39
3 changed files with 68 additions and 1 deletions

1
.gitignore vendored
View file

@ -163,3 +163,4 @@ cython_debug/
config.json
data.db
posts/
plots/

61
main.py
View file

@ -10,6 +10,10 @@ import asyncio
import aiosqlite
import base64
import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
@ -404,6 +408,63 @@ async def scores(ctx):
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
PLOTS = Path.cwd() / "plots"
PLOTS.mkdir(exist_ok=True)
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)
plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
def plot2(output_path, normalized_scores, model_scores):
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
if post_score < 0:
continue
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, output_path, width=1024, height=800)
async def realmain(ctx):
await ctx.db.executescript(

View file

@ -1,3 +1,8 @@
aiosqlite==0.19.0
aiohttp==3.8.4
aiolimiter==1.1.0
aiolimiter>1.1.0<2.0
plotly>5.15.0<6.0
pandas==2.0.2
kaleido==0.2.1