"""Download Danbooru posts and run them through several image taggers.

Modes are selected by ``sys.argv[1]`` (see ``realmain``):

* ``download_images [TAG_QUERY] [LIMIT]`` -- fetch posts into ``./posts`` and
  record them in the ``posts`` table of ``./data.db``.
* ``fight`` -- run every configured interrogator over each recorded post and
  store its tag string in ``interrogated_posts``.
* ``scores`` -- dispatched to a ``scores`` coroutine that is NOT defined in
  this file (NOTE(review): confirm where it lives).

Interrogator settings are read from ``./config.json``.
"""

import asyncio
import base64
import json
import logging
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, List, Optional
from urllib.parse import urlparse

import aiofiles
import aiohttp
import aiosqlite
from aiolimiter import AsyncLimiter

log = logging.getLogger(__name__)

# SD webui tagger models used when config.json does not set sd_webui_models.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]


def _expect_ok(resp) -> None:
    """Raise RuntimeError unless *resp* has HTTP status 200.

    This is an explicit check rather than ``assert``, so it still runs under
    ``python -O``.
    """
    if resp.status != 200:
        raise RuntimeError(f"unexpected HTTP status {resp.status} for {resp.url}")


def _to_tag_string(tags) -> str:
    """Join tag names with spaces, replacing spaces inside a tag with ``_``.

    >>> _to_tag_string(["long hair", "smile"])
    'long_hair smile'
    """
    return " ".join(tag.replace(" ", "_") for tag in tags)


def _filename_from_url(url: str) -> str:
    """Return the last path component of *url*, ignoring any query string."""
    return Path(urlparse(url).path).name


@dataclass
class Interrogator:
    """A tagging backend that turns an image file into a tag string."""

    model_id: str  # stored as interrogated_posts.model_name
    address: str  # base URL of the backend

    async def interrogate(self, session: aiohttp.ClientSession, path: Path) -> str:
        """Return the space-separated tag string for the image at *path*."""
        raise NotImplementedError


class DDInterrogator(Interrogator):
    """Interrogator for the ``dd_address`` endpoint (presumably DeepDanbooru)."""

    async def interrogate(self, session, path):
        # Open the file in a ``with`` block so the handle is closed once the
        # multipart upload has finished.
        with path.open("rb") as image_fd:
            async with session.post(
                f"{self.address}/",
                params={"threshold": "0.7"},
                # NOTE(review): hard-coded bearer token; consider moving to config.
                headers={"Authorization": "Bearer sex"},
                data={"file": image_fd},
            ) as resp:
                _expect_ok(resp)
                tags = await resp.json()
        return _to_tag_string(tags)


class SDInterrogator(Interrogator):
    """Interrogator for the SD webui ``/tagger/v1/interrogate`` API."""

    async def interrogate(self, session, path):
        async with aiofiles.open(path, "rb") as fd:
            as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
        async with session.post(
            f"{self.address}/tagger/v1/interrogate",
            json={
                "model": self.model_id,
                "threshold": 0.7,
                "image": as_base64,
            },
        ) as resp:
            log.info("got %d", resp.status)
            _expect_ok(resp)
            data = await resp.json()
        # "caption" is a mapping keyed by tag name; only the keys are kept.
        return _to_tag_string(data["caption"].keys())


@dataclass
class Config:
    """Settings loaded from config.json; its keys map 1:1 onto these fields."""

    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Interrogator]:
        """One SDInterrogator per configured SD model, followed by the DD one."""
        interrogators: List[Interrogator] = [
            SDInterrogator(model_id, self.sd_webui_address)
            for model_id in self.sd_webui_models
        ]
        interrogators.append(DDInterrogator(self.dd_model_name, self.dd_address))
        return interrogators


@dataclass
class Context:
    """Handles shared by every mode."""

    db: Any  # aiosqlite connection opened in main()
    session: aiohttp.ClientSession
    config: Config


class HashStyle(Enum):
    """Hash algorithm reported by ``Booru.hash_style``."""

    md5 = "md5"


@dataclass
class Booru:
    """Base class for rate-limited booru API clients."""

    session: aiohttp.ClientSession
    limiter: AsyncLimiter  # throttles post queries (see Danbooru.posts)
    tag_limiter: AsyncLimiter  # NOTE(review): not used within this file
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    # No annotation, so this is a plain class attribute, not a dataclass field.
    fetch_type = "hash"

    @property
    def hash_style(self):
        return HashStyle.md5


class Danbooru(Booru):
    """Client for the public Danbooru JSON API."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        """Return the decoded ``/posts.json`` response for *tag_query*."""
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                _expect_ok(resp)
                return await resp.json()


# Directory holding downloaded post files; created at import time.
DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)


async def _download(session, url: str, destination: Path) -> None:
    """Stream *url* into *destination* through a temporary file.

    The body is written to a hidden ``.<name>.part`` sibling, which is renamed
    over *destination* only after the whole body has arrived. An interrupted
    download therefore never leaves a truncated file at *destination* that a
    later ``exists()`` check would mistake for a finished one. The leading dot
    keeps the temp file out of ``fight``'s ``{md5}*`` glob. A stale temp file
    is simply overwritten on the next attempt.
    """
    partial = destination.with_name(f".{destination.name}.part")
    async with session.get(url) as resp:
        _expect_ok(resp)
        async with aiofiles.open(partial, "wb") as fd:
            async for chunk in resp.content.iter_chunked(1024):
                await fd.write(chunk)
    partial.replace(destination)


async def download_images(ctx):
    """Download posts matching ``argv[2]`` (default "") up to ``argv[3]`` (default 10).

    Posts already in the ``posts`` table are skipped. A post is inserted only
    after its file is on disk.
    """
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""
    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        # Both keys are required below; skip posts missing either one.
        if "md5" not in post or "file_url" not in post:
            continue
        log.info("processing post %r", post)
        if await fetch_post(ctx, post["md5"]):
            continue

        post_file_url = post["file_url"]
        post_filepath = DOWNLOADS / _filename_from_url(post_file_url)
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            await _download(ctx.session, post_file_url, post_filepath)

        # Reached only once the file is fully in place.
        await insert_post(ctx, post)


async def fetch_post(ctx, md5) -> Optional[dict]:
    """Return the stored post JSON for *md5*, or None if it is not recorded."""
    rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
    if not rows:
        return None
    # md5 is the primary key of ``posts``, so at most one row can match.
    assert len(rows) == 1
    return json.loads(rows[0][0])


async def insert_post(ctx, post):
    """Record *post* with its downloaded filename and raw JSON, then commit."""
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)",
        (
            post["md5"],
            _filename_from_url(post["file_url"]),
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str
):
    """Store *interrogator*'s tag string for the post *md5*, then commit."""
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)",
        (md5, interrogator.model_id, tag_string),
    )
    await ctx.db.commit()


async def fight(ctx):
    """Run every configured interrogator over each post it has not tagged yet.

    Each result is committed as soon as it is produced, and the work list is
    recomputed from ``interrogated_posts``, so a rerun resumes where an
    interrupted run stopped.
    """
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}
    for interrogator in ctx.config.all_available_models:
        log.info("processing for %r", interrogator)
        # Work out which images this model has not interrogated yet.
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        missing_hashes = all_hashes - {row[0] for row in interrogated_rows}
        log.info("missing %d hashes", len(missing_hashes))
        for missing_hash in missing_hashes:
            log.info("interrogating %r", missing_hash)
            # Without a default, next() would raise StopIteration here, which
            # a coroutine turns into an opaque RuntimeError.
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"), None)
            if post_filepath is None:
                log.warning("no downloaded file for %r, skipping", missing_hash)
                continue
            tag_string = await interrogator.interrogate(ctx.session, post_filepath)
            log.info("got %r", tag_string)
            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string
            )


async def realmain(ctx):
    """Create the schema if needed, then dispatch on the mode in ``sys.argv[1]``."""
    await ctx.db.executescript(
        """
        create table if not exists posts (
            md5 text primary key,
            filepath text,
            data text
        );
        create table if not exists interrogated_posts (
            md5 text,
            model_name text not null,
            output_tag_string text not null,
            primary key (md5, model_name)
        );
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode") from None

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    elif mode == "scores":
        # NOTE(review): ``scores`` is not defined in this file; this branch
        # raises NameError unless it is provided elsewhere.
        await scores(ctx)  # noqa: F821
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    """Open the database, the HTTP session and the config, then run realmain."""
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with config_path.open() as config_fd:
                config_json = json.load(config_fd)
            config = Config(**config_json)
            ctx = Context(db, session, config)
            await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())