add plotting for the error rate of taggers
This commit is contained in:
parent
c50bd85b7d
commit
d299e68526
1 changed files with 38 additions and 0 deletions
38
main.py
38
main.py
|
@ -411,6 +411,8 @@ async def scores(ctx):
|
|||
PLOTS = Path.cwd() / "plots"
|
||||
PLOTS.mkdir(exist_ok=True)
|
||||
|
||||
log.info("plotting score histogram...")
|
||||
|
||||
data_for_df = {}
|
||||
data_for_df["scores"] = []
|
||||
data_for_df["model"] = []
|
||||
|
@ -435,7 +437,10 @@ async def scores(ctx):
|
|||
)
|
||||
pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)
|
||||
|
||||
log.info("plotting positive histogram...")
|
||||
plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
|
||||
log.info("plotting error rates...")
|
||||
plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
|
||||
|
||||
|
||||
def plot2(output_path, normalized_scores, model_scores):
|
||||
|
@ -466,6 +471,39 @@ def plot2(output_path, normalized_scores, model_scores):
|
|||
pio.write_image(fig, output_path, width=1024, height=800)
|
||||
|
||||
|
||||
def plot3(output_path, normalized_scores, model_scores):
|
||||
data_for_df = {"model": [], "errors": [], "rating_errors": []}
|
||||
|
||||
for model in sorted(
|
||||
normalized_scores.keys(),
|
||||
key=lambda model: normalized_scores[model],
|
||||
reverse=True,
|
||||
):
|
||||
total_incorrect_tags = 0
|
||||
total_rating_errors = 0
|
||||
for score_data in model_scores[model].values():
|
||||
total_incorrect_tags += len(score_data["incorrect_tags"])
|
||||
total_rating_errors += sum(
|
||||
1
|
||||
for rating in ["general", "sensitive", "questionable", "explicit"]
|
||||
if rating in score_data["incorrect_tags"]
|
||||
)
|
||||
|
||||
data_for_df["errors"].append(total_incorrect_tags)
|
||||
data_for_df["rating_errors"].append(total_rating_errors)
|
||||
data_for_df["model"].append(model)
|
||||
|
||||
df = pd.DataFrame(data_for_df)
|
||||
|
||||
fig = go.Figure(
|
||||
data=[
|
||||
go.Bar(name="incorrect tags", x=df.model, y=df.errors),
|
||||
go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
|
||||
]
|
||||
)
|
||||
pio.write_image(fig, output_path, width=1024, height=800)
|
||||
|
||||
|
||||
async def realmain(ctx):
|
||||
await ctx.db.executescript(
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue