Compare commits

...

4 Commits

Author SHA1 Message Date
Luna d299e68526 add plotting for the error rate of taggers 2023-06-10 16:51:50 -03:00
Luna c50bd85b7d map deepdanbooru's safe to general 2023-06-10 16:51:26 -03:00
Luna 389a582e39 add histogram plotting of scores 2023-06-10 15:26:02 -03:00
Luna d8a4f6aaaf skip latest pages
they probably dont have good ground truth to begin with
2023-06-10 15:25:45 -03:00
3 changed files with 116 additions and 6 deletions

1
.gitignore vendored
View File

@ -163,3 +163,4 @@ cython_debug/
config.json config.json
data.db data.db
posts/ posts/
plots/

114
main.py
View File

@ -10,6 +10,10 @@ import asyncio
import aiosqlite import aiosqlite
import base64 import base64
import logging import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from collections import defaultdict, Counter from collections import defaultdict, Counter
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
@ -63,7 +67,7 @@ class DDInterrogator(Interrogator):
else: else:
original_danbooru_tag = tag original_danbooru_tag = tag
if original_danbooru_tag == "safe": if original_danbooru_tag == "safe":
continue original_danbooru_tag = "general"
new_lst.append(original_danbooru_tag) new_lst.append(original_danbooru_tag)
return new_lst return new_lst
@ -157,13 +161,13 @@ class Danbooru(Booru):
title = "Danbooru" title = "Danbooru"
base_url = "https://danbooru.donmai.us" base_url = "https://danbooru.donmai.us"
async def posts(self, tag_query: str, limit): async def posts(self, tag_query: str, limit, page: int):
log.info("%s: submit %r", self.title, tag_query) log.info("%s: submit %r", self.title, tag_query)
async with self.limiter: async with self.limiter:
log.info("%s: submit upstream %r", self.title, tag_query) log.info("%s: submit upstream %r", self.title, tag_query)
async with self.session.get( async with self.session.get(
f"{self.base_url}/posts.json", f"{self.base_url}/posts.json",
params={"tags": tag_query, "limit": limit}, params={"tags": tag_query, "limit": limit, "page": page},
) as resp: ) as resp:
assert resp.status == 200 assert resp.status == 200
rjson = await resp.json() rjson = await resp.json()
@ -190,7 +194,12 @@ async def download_images(ctx):
try: try:
limit = int(sys.argv[3]) limit = int(sys.argv[3])
except IndexError: except IndexError:
limit = 10 limit = 30
try:
pageskip = int(sys.argv[4])
except IndexError:
pageskip = 150
danbooru = Danbooru( danbooru = Danbooru(
ctx.session, ctx.session,
@ -198,7 +207,7 @@ async def download_images(ctx):
AsyncLimiter(1, 3), AsyncLimiter(1, 3),
) )
posts = await danbooru.posts(tagquery, limit) posts = await danbooru.posts(tagquery, limit, pageskip)
for post in posts: for post in posts:
if "md5" not in post: if "md5" not in post:
continue continue
@ -399,6 +408,101 @@ async def scores(ctx):
print("]") print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5)) print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
PLOTS = Path.cwd() / "plots"
PLOTS.mkdir(exist_ok=True)
log.info("plotting score histogram...")
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)
log.info("plotting positive histogram...")
plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
log.info("plotting error rates...")
plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
def plot2(output_path, normalized_scores, model_scores):
data_for_df = {}
data_for_df["scores"] = []
data_for_df["model"] = []
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
for post_score in (d["score"] for d in model_scores[model].values()):
if post_score < 0:
continue
data_for_df["scores"].append(post_score)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = px.histogram(
df,
x="scores",
color="model",
histfunc="count",
marginal="rug",
histnorm="probability",
)
pio.write_image(fig, output_path, width=1024, height=800)
def plot3(output_path, normalized_scores, model_scores):
data_for_df = {"model": [], "errors": [], "rating_errors": []}
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
reverse=True,
):
total_incorrect_tags = 0
total_rating_errors = 0
for score_data in model_scores[model].values():
total_incorrect_tags += len(score_data["incorrect_tags"])
total_rating_errors += sum(
1
for rating in ["general", "sensitive", "questionable", "explicit"]
if rating in score_data["incorrect_tags"]
)
data_for_df["errors"].append(total_incorrect_tags)
data_for_df["rating_errors"].append(total_rating_errors)
data_for_df["model"].append(model)
df = pd.DataFrame(data_for_df)
fig = go.Figure(
data=[
go.Bar(name="incorrect tags", x=df.model, y=df.errors),
go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
]
)
pio.write_image(fig, output_path, width=1024, height=800)
async def realmain(ctx): async def realmain(ctx):
await ctx.db.executescript( await ctx.db.executescript(

View File

@ -1,3 +1,8 @@
aiosqlite==0.19.0 aiosqlite==0.19.0
aiohttp==3.8.4 aiohttp==3.8.4
aiolimiter==1.1.0 aiolimiter>1.1.0<2.0
plotly>5.15.0<6.0
pandas==2.0.2
kaleido==0.2.1