From d8a4f6aaaf03f3c6ee900a0775e05f3a8daba32e Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 15:25:45 -0300 Subject: [PATCH 1/4] skip latest pages they probably dont have good ground truth to begin with --- main.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index a6f0074..6f36b3d 100644 --- a/main.py +++ b/main.py @@ -157,13 +157,13 @@ class Danbooru(Booru): title = "Danbooru" base_url = "https://danbooru.donmai.us" - async def posts(self, tag_query: str, limit): + async def posts(self, tag_query: str, limit, page: int): log.info("%s: submit %r", self.title, tag_query) async with self.limiter: log.info("%s: submit upstream %r", self.title, tag_query) async with self.session.get( f"{self.base_url}/posts.json", - params={"tags": tag_query, "limit": limit}, + params={"tags": tag_query, "limit": limit, "page": page}, ) as resp: assert resp.status == 200 rjson = await resp.json() @@ -190,7 +190,12 @@ async def download_images(ctx): try: limit = int(sys.argv[3]) except IndexError: - limit = 10 + limit = 30 + + try: + pageskip = int(sys.argv[4]) + except IndexError: + pageskip = 150 danbooru = Danbooru( ctx.session, @@ -198,7 +203,7 @@ async def download_images(ctx): AsyncLimiter(1, 3), ) - posts = await danbooru.posts(tagquery, limit) + posts = await danbooru.posts(tagquery, limit, pageskip) for post in posts: if "md5" not in post: continue From 389a582e39139c8e235412018ceec835aca39664 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 15:26:02 -0300 Subject: [PATCH 2/4] add histogram plotting of scores --- .gitignore | 1 + main.py | 61 ++++++++++++++++++++++++++++++++++++++++++++++++ requirements.txt | 7 +++++- 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f7ab5d7..8adcec9 100644 --- a/.gitignore +++ b/.gitignore @@ -163,3 +163,4 @@ cython_debug/ config.json data.db posts/ +plots/ diff --git a/main.py b/main.py index 6f36b3d..c5384d0 100644 --- 
a/main.py +++ b/main.py @@ -10,6 +10,10 @@ import asyncio import aiosqlite import base64 import logging +import pandas as pd +import plotly.express as px +import plotly.graph_objs as go +import plotly.io as pio from collections import defaultdict, Counter from pathlib import Path from urllib.parse import urlparse @@ -404,6 +408,63 @@ async def scores(ctx): print("]") print("most incorrect tags", incorrect_tags_counters[model].most_common(5)) + PLOTS = Path.cwd() / "plots" + PLOTS.mkdir(exist_ok=True) + + data_for_df = {} + data_for_df["scores"] = [] + data_for_df["model"] = [] + + for model in sorted( + normalized_scores.keys(), + key=lambda model: normalized_scores[model], + reverse=True, + ): + for post_score in (d["score"] for d in model_scores[model].values()): + data_for_df["scores"].append(post_score) + data_for_df["model"].append(model) + + df = pd.DataFrame(data_for_df) + fig = px.histogram( + df, + x="scores", + color="model", + histfunc="count", + marginal="rug", + histnorm="probability", + ) + pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800) + + plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores) + + +def plot2(output_path, normalized_scores, model_scores): + data_for_df = {} + data_for_df["scores"] = [] + data_for_df["model"] = [] + + for model in sorted( + normalized_scores.keys(), + key=lambda model: normalized_scores[model], + reverse=True, + ): + for post_score in (d["score"] for d in model_scores[model].values()): + if post_score < 0: + continue + data_for_df["scores"].append(post_score) + data_for_df["model"].append(model) + + df = pd.DataFrame(data_for_df) + fig = px.histogram( + df, + x="scores", + color="model", + histfunc="count", + marginal="rug", + histnorm="probability", + ) + pio.write_image(fig, output_path, width=1024, height=800) + async def realmain(ctx): await ctx.db.executescript( diff --git a/requirements.txt b/requirements.txt index db0e4df..391859f 100644 --- a/requirements.txt 
+++ b/requirements.txt @@ -1,3 +1,8 @@ aiosqlite==0.19.0 aiohttp==3.8.4 -aiolimiter==1.1.0 \ No newline at end of file +aiolimiter>=1.1.0,<2.0 + + +plotly>=5.15.0,<6.0 +pandas==2.0.2 +kaleido==0.2.1 \ No newline at end of file From c50bd85b7dbc61750783514c7866ef3179b1beaa Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 16:51:26 -0300 Subject: [PATCH 3/4] map deepdanbooru's safe to general --- main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.py b/main.py index c5384d0..7ef1f73 100644 --- a/main.py +++ b/main.py @@ -67,7 +67,7 @@ class DDInterrogator(Interrogator): else: original_danbooru_tag = tag if original_danbooru_tag == "safe": - continue + original_danbooru_tag = "general" new_lst.append(original_danbooru_tag) return new_lst From d299e68526b0f659abbbc2d76dbe81f9b641d0aa Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 16:51:50 -0300 Subject: [PATCH 4/4] add plotting for the error rate of taggers --- main.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/main.py b/main.py index 7ef1f73..83e5abe 100644 --- a/main.py +++ b/main.py @@ -411,6 +411,8 @@ async def scores(ctx): PLOTS = Path.cwd() / "plots" PLOTS.mkdir(exist_ok=True) + log.info("plotting score histogram...") + data_for_df = {} data_for_df["scores"] = [] data_for_df["model"] = [] @@ -435,7 +437,10 @@ async def scores(ctx): ) pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800) + log.info("plotting positive histogram...") plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores) + log.info("plotting error rates...") + plot3(PLOTS / "error_rate.png", normalized_scores, model_scores) def plot2(output_path, normalized_scores, model_scores): @@ -466,6 +471,39 @@ def plot2(output_path, normalized_scores, model_scores): pio.write_image(fig, output_path, width=1024, height=800) +def plot3(output_path, normalized_scores, model_scores): + data_for_df = {"model": [], "errors": [],
"rating_errors": []} + + for model in sorted( + normalized_scores.keys(), + key=lambda model: normalized_scores[model], + reverse=True, + ): + total_incorrect_tags = 0 + total_rating_errors = 0 + for score_data in model_scores[model].values(): + total_incorrect_tags += len(score_data["incorrect_tags"]) + total_rating_errors += sum( + 1 + for rating in ["general", "sensitive", "questionable", "explicit"] + if rating in score_data["incorrect_tags"] + ) + + data_for_df["errors"].append(total_incorrect_tags) + data_for_df["rating_errors"].append(total_rating_errors) + data_for_df["model"].append(model) + + df = pd.DataFrame(data_for_df) + + fig = go.Figure( + data=[ + go.Bar(name="incorrect tags", x=df.model, y=df.errors), + go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors), + ] + ) + pio.write_image(fig, output_path, width=1024, height=800) + + async def realmain(ctx): await ctx.db.executescript( """