Compare commits

...

4 Commits

Author SHA1 Message Date
Luna d299e68526 add plotting for the error rate of taggers 2023-06-10 16:51:50 -03:00
Luna c50bd85b7d map deepdanbooru's safe to general 2023-06-10 16:51:26 -03:00
Luna 389a582e39 add histogram plotting of scores 2023-06-10 15:26:02 -03:00
Luna d8a4f6aaaf skip latest pages
they probably dont have good ground truth to begin with
2023-06-10 15:25:45 -03:00
3 changed files with 116 additions and 6 deletions

1
.gitignore vendored
View File

@ -163,3 +163,4 @@ cython_debug/
config.json
data.db
posts/
plots/

114
main.py
View File

@ -10,6 +10,10 @@ import asyncio
import aiosqlite
import base64
import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
@ -63,7 +67,7 @@ class DDInterrogator(Interrogator):
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
continue
original_danbooru_tag = "general"
new_lst.append(original_danbooru_tag)
return new_lst
@ -157,13 +161,13 @@ class Danbooru(Booru):
title = "Danbooru"
base_url = "https://danbooru.donmai.us"
async def posts(self, tag_query: str, limit):
async def posts(self, tag_query: str, limit, page: int):
log.info("%s: submit %r", self.title, tag_query)
async with self.limiter:
log.info("%s: submit upstream %r", self.title, tag_query)
async with self.session.get(
f"{self.base_url}/posts.json",
params={"tags": tag_query, "limit": limit},
params={"tags": tag_query, "limit": limit, "page": page},
) as resp:
assert resp.status == 200
rjson = await resp.json()
@ -190,7 +194,12 @@ async def download_images(ctx):
try:
limit = int(sys.argv[3])
except IndexError:
limit = 10
limit = 30
try:
pageskip = int(sys.argv[4])
except IndexError:
pageskip = 150
danbooru = Danbooru(
ctx.session,
@ -198,7 +207,7 @@ async def download_images(ctx):
AsyncLimiter(1, 3),
)
posts = await danbooru.posts(tagquery, limit)
posts = await danbooru.posts(tagquery, limit, pageskip)
for post in posts:
if "md5" not in post:
continue
@ -399,6 +408,101 @@ async def scores(ctx):
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
PLOTS = Path.cwd() / "plots"
PLOTS.mkdir(exist_ok=True)
log.info("plotting score histogram...")
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)
log.info("plotting positive histogram...")
plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
log.info("plotting error rates...")
plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
def plot2(output_path, normalized_scores, model_scores):
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
if post_score < 0:
continue
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, output_path, width=1024, height=800)
def plot3(output_path, normalized_scores, model_scores):
data_for_df = {"model": [], "errors": [], "rating_errors": []}
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
total_incorrect_tags = 0
total_rating_errors = 0
for score_data in model_scores[model].values():
total_incorrect_tags += len(score_data["incorrect_tags"])
total_rating_errors += sum(
1
for rating in ["general", "sensitive", "questionable", "explicit"]
if rating in score_data["incorrect_tags"]
)
data_for_df["errors"].append(total_incorrect_tags)
data_for_df["rating_errors"].append(total_rating_errors)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = go.Figure(
data=[
go.Bar(name="incorrect tags", x=df.model, y=df.errors),
go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
]
)
pio.write_image(fig, output_path, width=1024, height=800)
async def realmain(ctx):
await ctx.db.executescript(

View File

@ -1,3 +1,8 @@
aiosqlite==0.19.0
aiohttp==3.8.4
aiolimiter==1.1.0
aiolimiter>1.1.0<2.0
plotly>5.15.0<6.0
pandas==2.0.2
kaleido==0.2.1