import asyncio
import base64
import decimal
import json
import logging
import os
import pprint
import sys
import time
from collections import defaultdict, Counter
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional, List, Set, Tuple
from urllib.parse import urlparse

import aiofiles
import aiohttp
import aiosqlite
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from aiolimiter import AsyncLimiter

log = logging.getLogger(__name__)


# Model ids used for ``Config.sd_webui_models`` when the config does not
# provide its own list.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]


@dataclass
class InterrogatorPost:
    """Stored result of one interrogation: the tags and how long it took."""

    tags: List[str]
    time_taken: float


@dataclass
class Interrogator:
    """Base class for a tagger identified by ``model_id`` at ``address``.

    Subclasses implement ``interrogate`` and may override ``_process`` to
    normalize stored tags when they are read back.
    """

    model_id: str
    # ControlInterrogator is constructed with ``None`` here, hence Optional.
    address: Optional[str]

    def _process(self, lst):
        """Return the stored tag list unchanged (hook for subclasses)."""
        return lst

    async def interrogate(self, ctx, path):
        """Return a space-separated tag string for the image at *path*.

        Raises:
            NotImplementedError: always; subclasses provide the real tagger.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement interrogate()"
        )

    async def fetch(self, ctx, md5_hash):
        """Load this model's stored result for *md5_hash* from the database.

        Returns:
            An ``InterrogatorPost`` with the ``_process``-ed tags and the
            recorded ``time_taken``.

        Raises:
            AssertionError: if there is not exactly one stored row.
        """
        rows = await ctx.db.execute_fetchall(
            "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
            (md5_hash, self.model_id),
        )
        # A missing row means this post was never interrogated by this model.
        assert len(rows) == 1, (
            f"expected 1 interrogated row for {md5_hash!r}/{self.model_id!r}, "
            f"got {len(rows)}; run 'fight' mode, there's missing posts"
        )
        tag_string, time_taken = rows[0][0], rows[0][1]
        return InterrogatorPost(self._process(tag_string.split()), time_taken)


class DDInterrogator(Interrogator):
    """Tagger reached by POSTing the image file to ``{address}/``."""

    def _process(self, lst):
        """Strip the ``rating:`` prefix and rename ``safe`` to ``general``."""
        new_lst = []
        for tag in lst:
            # removeprefix keeps everything after the prefix, even if the
            # remainder itself contains a colon.
            original_danbooru_tag = tag.removeprefix("rating:")
            if original_danbooru_tag == "safe":
                original_danbooru_tag = "general"
            new_lst.append(original_danbooru_tag)
        return new_lst

    async def interrogate(self, ctx, path):
        """Upload *path* and return the response tags joined by spaces.

        Spaces inside a tag are replaced with underscores.

        Raises:
            AssertionError: if the server does not answer with HTTP 200.
        """
        # Open the image in a ``with`` block so the handle is always closed,
        # including when the request fails before the body is sent.
        with path.open("rb") as image_fd:
            async with ctx.session.post(
                f"{self.address}/",
                params={
                    "threshold": "0.7",
                },
                # NOTE(review): credential is hard-coded in source; consider
                # moving it to the config file.
                headers={"Authorization": "Bearer sex"},
                data={"file": image_fd},
            ) as resp:
                assert resp.status == 200
                # NOTE(review): the JSON body is iterated directly below, so
                # it is presumably a list of tags or a mapping keyed by tag.
                tags = await resp.json()
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)


class SDInterrogator(Interrogator):
    """Tagger reached through ``{address}/tagger/v1/interrogate``."""

    async def interrogate(self, ctx, path):
        """Send *path* base64-encoded and return the caption tags.

        The tags are the keys of the response's ``caption`` mapping, with
        spaces replaced by underscores and joined by single spaces.

        Raises:
            AssertionError: if the server does not answer with HTTP 200.
        """
        async with aiofiles.open(path, "rb") as fd:
            as_base64 = base64.b64encode(await fd.read()).decode("utf-8")

        async with ctx.session.post(
            f"{self.address}/tagger/v1/interrogate",
            json={
                "model": self.model_id,
                "threshold": 0.7,
                "image": as_base64,
            },
        ) as resp:
            log.info("got %d", resp.status)
            assert resp.status == 200
            data = await resp.json()
        tags_with_scores = data["caption"]
        tags = list(tags_with_scores.keys())
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)


class ControlInterrogator(Interrogator):
    """Pseudo-tagger that returns the post's own stored tags."""

    async def interrogate(self, ctx, path):
        """Return the stored ``tag_string`` of the post named by *path*.

        The file stem of *path* is used as the post's md5.
        """
        md5_hash = Path(path).stem
        post = await fetch_post(ctx, md5_hash)
        return post["tag_string"]


@dataclass
class Config:
    """Settings loaded from the JSON config file."""

    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Any]:
        """Build every interrogator: the SD models, the DD model, then control."""
        return (
            [
                SDInterrogator(sd_interrogator, self.sd_webui_address)
                for sd_interrogator in self.sd_webui_models
            ]
            + [DDInterrogator(self.dd_model_name, self.dd_address)]
            + [ControlInterrogator("control", None)]
        )


@dataclass
class Context:
    """Shared handles passed to every mode: database, HTTP session, config."""

    db: Any
    session: aiohttp.ClientSession
    config: Config


@dataclass
class Booru:
    """Base booru client holding an HTTP session and rate limiters."""

    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))

    fetch_type = "hash"

    @property
    def hash_style(self):
        # NOTE(review): ``HashStyle`` is not defined in the visible part of
        # this file; accessing this property will raise NameError unless it
        # is provided elsewhere.
        return HashStyle.md5


class Danbooru(Booru):
    """Client for the Danbooru ``posts.json`` endpoint."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit, page: int):
        """Return the decoded JSON of one ``posts.json`` page for *tag_query*.

        The request is made while holding ``self.limiter``.

        Raises:
            AssertionError: if the server does not answer with HTTP 200.
        """
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit, "page": page},
            ) as resp:
                assert resp.status == 200
                rjson = await resp.json()
                return rjson


# Directory downloaded post files are written to; created at import time.
DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)
ALLOWED_EXTENSIONS = (
    "jpg",
    "jpeg",
    "png",
)

# One-letter post ``rating`` values mapped to the tag appended to the
# post's ``tag_string`` by ``fetch_post``.
_RATING_TAGS = {
    "g": "general",
    "s": "sensitive",
    "q": "questionable",
    "e": "explicit",
}

# Rating tag names counted separately in the error-rate plot.
_RATING_TAG_NAMES = tuple(_RATING_TAGS.values())


def _cli_arg(index, default, convert=str):
    """Return ``convert(sys.argv[index])``, or *default* if it is absent."""
    try:
        raw = sys.argv[index]
    except IndexError:
        return default
    return convert(raw)


def _has_allowed_extension(suffix: str) -> bool:
    """Tell whether a path suffix such as ``.jpg`` is an allowed extension.

    The match is exact and case-insensitive, so ``.jpgx`` is rejected.
    """
    return suffix.lower().lstrip(".") in ALLOWED_EXTENSIONS


def _rating_tag(post_rating) -> str:
    """Translate a one-letter post rating into its rating tag.

    Raises:
        AssertionError: if *post_rating* is not one of ``g``/``s``/``q``/``e``.
    """
    try:
        return _RATING_TAGS[post_rating]
    except (KeyError, TypeError):
        raise AssertionError(f"invalid post rating {post_rating!r}") from None


async def _download_file(ctx, url: str, destination: Path) -> None:
    """Stream *url* into *destination*, renaming into place only on success.

    The body is first written to a hidden ``.<name>.part`` file next to
    *destination*, so an interrupted download never leaves a truncated file
    under the final name.
    """
    partial_path = destination.with_name(f".{destination.name}.part")
    log.info("downloading %r to %r", url, destination)
    async with ctx.session.get(url) as resp:
        assert resp.status == 200
        with partial_path.open("wb") as fd:
            async for chunk in resp.content.iter_chunked(64 * 1024):
                fd.write(chunk)
    partial_path.replace(destination)


async def download_images(ctx):
    """Fetch one page of Danbooru posts and store new ones on disk and in the db.

    Optional command-line arguments: ``argv[2]`` is the tag query (default
    ``""``), ``argv[3]`` the post limit (default 30) and ``argv[4]`` the page
    number (default 150).

    Posts without an ``md5``, posts already in the database and posts whose
    file extension is not in ``ALLOWED_EXTENSIONS`` are skipped.
    """
    tagquery = _cli_arg(2, "")
    limit = _cli_arg(3, 30, int)
    pageskip = _cli_arg(4, 150, int)

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )

    posts = await danbooru.posts(tagquery, limit, pageskip)
    for post in posts:
        if "md5" not in post:
            continue
        md5 = post["md5"]
        log.info("processing post %r", md5)
        existing_post = await fetch_post(ctx, md5)
        if existing_post:
            log.info("already exists %r", md5)
            continue

        # download the post
        post_file_url = post["file_url"]
        post_url_path = Path(urlparse(post_file_url).path)
        if not _has_allowed_extension(post_url_path.suffix):
            log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
            continue

        post_filepath = DOWNLOADS / post_url_path.name
        if not post_filepath.exists():
            await _download_file(ctx, post_file_url, post_filepath)

        # when it's done successfully, insert it
        await insert_post(ctx, post)


async def fetch_post(ctx, md5) -> Optional[dict]:
    """Load the stored post for *md5*, with its rating tag appended.

    Returns:
        The decoded post dict whose ``tag_string`` has the rating tag added
        at the end, or ``None`` if no such post is stored.

    Raises:
        AssertionError: if more than one row matches, or the post's
            ``rating`` is not a known one-letter rating.
    """
    rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
    if not rows:
        return None
    assert len(rows) == 1
    post = json.loads(rows[0][0])
    post["tag_string"] = post["tag_string"] + " " + _rating_tag(post["rating"])
    return post


async def insert_post(ctx, post):
    """Insert *post* (md5, file name from its URL, raw JSON) and commit."""
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)",
        (
            post["md5"],
            Path(urlparse(post["file_url"]).path).name,
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def insert_interrogated_result(
    ctx, interrogator: "Interrogator", md5: str, tag_string: str, time_taken: float
):
    """Store one interrogation result for (*md5*, model) and commit."""
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string, time_taken) values (?,?,?,?)",
        (md5, interrogator.model_id, tag_string, time_taken),
    )
    await ctx.db.commit()


async def fight(ctx):
    """Run every configured interrogator over the posts it has not seen yet.

    Each result is timed with ``time.monotonic`` and stored in the database.

    Raises:
        FileNotFoundError: if a stored post has no matching file in
            ``DOWNLOADS``.
    """
    interrogators = ctx.config.all_available_models

    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}

    for interrogator in interrogators:
        log.info("processing for %r", interrogator)
        # calculate set of images we didn't interrogate yet
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = {row[0] for row in interrogated_rows}
        missing_hashes = all_hashes - interrogated_hashes

        log.info("missing %d hashes", len(missing_hashes))

        for index, missing_hash in enumerate(missing_hashes, start=1):
            log.info(
                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
            )
            # A bare next() would raise StopIteration here, which a coroutine
            # turns into an opaque RuntimeError; fail with a clear error.
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"), None)
            if post_filepath is None:
                raise FileNotFoundError(
                    f"no downloaded file for post {missing_hash!r} in {DOWNLOADS}"
                )

            start_ts = time.monotonic()
            tag_string = await interrogator.interrogate(ctx, post_filepath)
            end_ts = time.monotonic()
            time_taken = round(end_ts - start_ts, 10)

            log.info("took %.5fsec, got %r", time_taken, tag_string)
            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string, time_taken
            )


def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score a tagger's output against the reference tags.

    The score is ``(matching - extra) / len(danbooru_tags)`` rounded to 10
    places, where *matching* are tagger tags present in the reference and
    *extra* are tagger tags absent from it.

    Returns:
        ``(score, extra_tags)``.

    Note:
        *danbooru_tags* must be non-empty; the division fails otherwise.
    """
    tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    return (
        round(
            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
            / decimal.Decimal(len(danbooru_tags)),
            10,
        ),
        tags_not_in_danbooru,
    )


def _models_best_first(normalized_scores):
    """Return the model ids ordered from highest to lowest normalized score."""
    return sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    )


def _print_score_extremes(post_scores, *, debug):
    """Print the 4 lowest and 4 highest per-post scores of one model.

    With *debug* set, each of the lowest entries is printed on its own line
    together with its md5 and incorrect tags.
    """

    def by_score(md5_hash):
        return post_scores[md5_hash]["score"]

    print("[", end="")
    for bad_md5_hash in sorted(post_scores.keys(), key=by_score)[:4]:
        data = post_scores[bad_md5_hash]
        if debug:
            print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
        else:
            print(data["score"], end=",")
    print("...", end="")
    for good_md5_hash in sorted(post_scores.keys(), key=by_score, reverse=True)[:4]:
        data = post_scores[good_md5_hash]
        print(data["score"], end=",")
    print("]")


def _write_score_histogram(
    output_path, normalized_scores, model_scores, *, skip_negative
):
    """Write a per-model histogram of post scores to *output_path*.

    With *skip_negative* set, scores below zero are left out.
    """
    data_for_df = {"scores": [], "model": []}
    for model in _models_best_first(normalized_scores):
        for score_data in model_scores[model].values():
            post_score = score_data["score"]
            if skip_negative and post_score < 0:
                continue
            data_for_df["scores"].append(post_score)
            data_for_df["model"].append(model)

    df = pd.DataFrame(data_for_df)
    fig = px.histogram(
        df,
        x="scores",
        color="model",
        histfunc="count",
        marginal="rug",
        histnorm="probability",
    )
    pio.write_image(fig, output_path, width=1024, height=800)


async def scores(ctx):
    """Score every model against the stored posts, print a report, write plots.

    Prints one line per model (best normalized score first) with its average
    runtime and its five most common incorrect tags. When the ``SHOWOFF``
    environment variable is ``1`` the lowest and highest per-post scores are
    printed too; ``DEBUG=1`` adds the md5 and incorrect tags to the lowest
    ones. Three PNG plots are written to ``./plots``.
    """
    interrogators = ctx.config.all_available_models

    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}

    model_scores = defaultdict(dict)
    runtimes = defaultdict(list)
    incorrect_tags_counters = defaultdict(Counter)
    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())
        for interrogator in interrogators:
            post_data = await interrogator.fetch(ctx, md5_hash)
            runtimes[interrogator.model_id].append(post_data.time_taken)
            interrogator_tags = set(post_data.tags)
            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
            incorrect_tags_counters[interrogator.model_id].update(incorrect_tags)
            model_scores[interrogator.model_id][md5_hash] = {
                "score": tagging_score,
                "incorrect_tags": incorrect_tags,
            }

    summed_scores = {
        model_id: sum(d["score"] for d in post_scores.values())
        for model_id, post_scores in model_scores.items()
    }

    normalized_scores = {
        model: round(summed_scores[model] / len(all_hashes), 10)
        for model in summed_scores
    }

    showoff = os.getenv("SHOWOFF", "0") == "1"
    debug = os.getenv("DEBUG", "0") == "1"

    print("scores are [worst...best]")

    for model in _models_best_first(normalized_scores):
        average_runtime = sum(runtimes[model]) / len(runtimes[model])
        print(model, normalized_scores[model], "runtime", average_runtime, "sec")
        if showoff:
            _print_score_extremes(model_scores[model], debug=debug)

        # NOTE(review): printed for every model regardless of SHOWOFF; confirm
        # this matches the intended nesting.
        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))

    PLOTS = Path.cwd() / "plots"
    PLOTS.mkdir(exist_ok=True)

    log.info("plotting score histogram...")
    _write_score_histogram(
        PLOTS / "score_histogram.png",
        normalized_scores,
        model_scores,
        skip_negative=False,
    )
    log.info("plotting positive histogram...")
    plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
    log.info("plotting error rates...")
    plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)


def plot2(output_path, normalized_scores, model_scores):
    """Write the per-model histogram of non-negative post scores."""
    _write_score_histogram(
        output_path, normalized_scores, model_scores, skip_negative=True
    )


def plot3(output_path, normalized_scores, model_scores):
    """Write a bar chart of incorrect tags and incorrect ratings per model.

    For each model, "errors" is the total number of incorrect tags over all
    posts and "rating_errors" counts how many of those are rating tags.
    """
    data_for_df = {"model": [], "errors": [], "rating_errors": []}

    for model in _models_best_first(normalized_scores):
        total_incorrect_tags = 0
        total_rating_errors = 0
        for score_data in model_scores[model].values():
            total_incorrect_tags += len(score_data["incorrect_tags"])
            total_rating_errors += sum(
                1
                for rating in _RATING_TAG_NAMES
                if rating in score_data["incorrect_tags"]
            )

        data_for_df["errors"].append(total_incorrect_tags)
        data_for_df["rating_errors"].append(total_rating_errors)
        data_for_df["model"].append(model)

    df = pd.DataFrame(data_for_df)

    fig = go.Figure(
        data=[
            go.Bar(name="incorrect tags", x=df.model, y=df.errors),
            go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
        ]
    )
    pio.write_image(fig, output_path, width=1024, height=800)


async def realmain(ctx):
    """Create the tables if needed and dispatch on the mode in ``argv[1]``.

    Modes: ``download_images``, ``fight``, ``scores``.

    Raises:
        Exception: if no mode argument was given.
        AssertionError: if the mode is not one of the known modes.
    """
    await ctx.db.executescript(
        """
        create table if not exists posts (
            md5 text primary key,
            filepath text,
            data text
        );
        create table if not exists interrogated_posts (
            md5 text,
            model_name text not null,
            output_tag_string text not null,
            time_taken real not null,
            primary key (md5, model_name)
        );
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode")

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    elif mode == "scores":
        await scores(ctx)
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    """Open ``./data.db`` and an HTTP session, load ``config.json``, run."""
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with config_path.open() as config_fd:
                config_json = json.load(config_fd)

            config = Config(**config_json)
            ctx = Context(db, session, config)
            await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())