import json
import asyncio
import base64
import decimal
import logging
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional, Set
from urllib.parse import urlparse

import aiofiles
import aiohttp
import aiosqlite
from aiolimiter import AsyncLimiter

log = logging.getLogger(__name__)

# sd-webui tagger models to query by default
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]


@dataclass
class Interrogator:
    model_id: str
    address: str


class DDInterrogator(Interrogator):
    async def interrogate(self, session, path):
        # close the file handle once the upload finishes
        with path.open("rb") as fd:
            async with session.post(
                f"{self.address}/",
                params={"threshold": "0.7"},
                headers={"Authorization": "Bearer sex"},
                data={"file": fd},
            ) as resp:
                assert resp.status == 200
                tags = await resp.json()
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)


class SDInterrogator(Interrogator):
    async def interrogate(self, session, path):
        async with aiofiles.open(path, "rb") as fd:
            as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
        async with session.post(
            f"{self.address}/tagger/v1/interrogate",
            json={
                "model": self.model_id,
                "threshold": 0.7,
                "image": as_base64,
            },
        ) as resp:
            log.info("got %d", resp.status)
            assert resp.status == 200
            data = await resp.json()
        tags_with_scores = data["caption"]
        tags = list(tags_with_scores.keys())
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)


@dataclass
class Config:
    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Interrogator]:
        return [
            SDInterrogator(sd_interrogator, self.sd_webui_address)
            for sd_interrogator in self.sd_webui_models
        ] + [DDInterrogator(self.dd_model_name, self.dd_address)]
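
# A config.json shape this script accepts: the keys mirror the Config
# dataclass above. The addresses are placeholders (a local sd-webui and
# DeepDanbooru server are assumed), "deepdanbooru-v3" is a hypothetical
# model name, and "sd_webui_models" may be omitted to fall back to DEFAULTS:
#
# {
#     "sd_webui_address": "http://localhost:7860",
#     "dd_address": "http://localhost:8000",
#     "dd_model_name": "deepdanbooru-v3"
# }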

@dataclass
class Context:
    db: Any
    session: aiohttp.ClientSession
    config: Config


class HashStyle(Enum):
    # minimal stand-in: HashStyle is not defined elsewhere in this file,
    # and only md5 is ever used
    md5 = "md5"


@dataclass
class Booru:
    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    fetch_type = "hash"

    @property
    def hash_style(self):
        return HashStyle.md5


class Danbooru(Booru):
    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                assert resp.status == 200
                return await resp.json()


DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)

ALLOWED_EXTENSIONS = ("jpg", "jpeg", "png")


async def download_images(ctx):
    try:
        tag_query = sys.argv[2]
    except IndexError:
        tag_query = ""
    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10

    danbooru = Danbooru(ctx.session, AsyncLimiter(1, 10), AsyncLimiter(1, 3))
    posts = await danbooru.posts(tag_query, limit)
    for post in posts:
        if "md5" not in post:
            continue
        md5 = post["md5"]
        log.info("processing post %r", md5)
        existing_post = await fetch_post(ctx, md5)
        if existing_post:
            continue

        # the extension lives in the url path's suffix
        post_file_url = post["file_url"]
        post_url_path = Path(urlparse(post_file_url).path)
        extension = post_url_path.suffix.lstrip(".").lower()
        if extension not in ALLOWED_EXTENSIONS:
            log.info("ignoring %r, invalid extension (%r)", md5, extension)
            continue

        post_filename = post_url_path.name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            async with ctx.session.get(post_file_url) as resp:
                assert resp.status == 200
                async with aiofiles.open(post_filepath, "wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        await fd.write(chunk)

        # only insert the post once the download completed successfully
        await insert_post(ctx, post)


async def fetch_post(ctx, md5) -> Optional[dict]:
    rows = await ctx.db.execute_fetchall(
        "select data from posts where md5 = ?", (md5,)
    )
    if not rows:
        return None
    assert len(rows) == 1
    return json.loads(rows[0][0])


async def insert_post(ctx, post):
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath, data) values (?, ?, ?)",
        (
            post["md5"],
            Path(urlparse(post["file_url"]).path).name,
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str
):
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string)"
        " values (?, ?, ?)",
        (md5, interrogator.model_id, tag_string),
    )
    await ctx.db.commit()


async def fight(ctx):
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = set(row[0] for row in all_rows)

    for interrogator in interrogators:
        log.info("processing for %r", interrogator)

        # calculate the set of images this model hasn't interrogated yet
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = set(row[0] for row in interrogated_rows)
        missing_hashes = all_hashes - interrogated_hashes
        log.info("missing %d hashes", len(missing_hashes))

        for index, missing_hash in enumerate(missing_hashes, start=1):
            log.info(
                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
            )
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
            tag_string = await interrogator.interrogate(ctx.session, post_filepath)
            log.info("got %r", tag_string)
            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string
            )


def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
    # reward the tags the model shares with danbooru, penalize the danbooru
    # tags it missed, and normalize by the size of the danbooru tag set
    matched_tags = danbooru_tags & interrogator_tags
    missed_tags = danbooru_tags - interrogator_tags
    return decimal.Decimal(
        len(matched_tags) - len(missed_tags)
    ) / decimal.Decimal(len(danbooru_tags))
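
# A quick sanity check of score(): with danbooru tags {a, b, c, d} and model
# output {a, b, x}, two tags match and two are missed, so the score is
# (2 - 2) / 4 = 0. Extra tags the model invents (x here) are not penalized
# directly; they only fail to add to the matched count.
#
#     >>> score({"a", "b", "c", "d"}, {"a", "b", "x"})
#     Decimal('0')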

async def scores(ctx):
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = set(row[0] for row in all_rows)

    absolute_scores = defaultdict(decimal.Decimal)
    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())

        for interrogator in interrogators:
            rows = await ctx.db.execute_fetchall(
                "select output_tag_string from interrogated_posts"
                " where md5 = ? and model_name = ?",
                (md5_hash, interrogator.model_id),
            )
            # if this fails there are uninterrogated posts; run 'fight' first
            assert len(rows) == 1
            tag_string = rows[0][0]
            interrogator_tags = set(tag_string.split())
            absolute_scores[interrogator.model_id] += score(
                danbooru_tags, interrogator_tags
            )

    normalized_scores = {
        model: absolute_scores[model] / len(all_hashes)
        for model in absolute_scores
    }
    for model in sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    ):
        print(model, normalized_scores[model])


async def realmain(ctx):
    await ctx.db.executescript(
        """
        create table if not exists posts (
            md5 text primary key,
            filepath text,
            data text
        );
        create table if not exists interrogated_posts (
            md5 text,
            model_name text not null,
            output_tag_string text not null,
            primary key (md5, model_name)
        );
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode")

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    elif mode == "scores":
        await scores(ctx)
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with config_path.open() as config_fd:
                config = Config(**json.load(config_fd))
            ctx = Context(db, session, config)
            await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
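
# Example invocations (the filename "fight.py" and the tag query are
# hypothetical; the tag query and limit arguments are only read by
# download_images and default to "" and 10):
#
#   python fight.py download_images "fox_ears" 50
#   python fight.py fight
#   python fight.py scores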