From dbf458eed7d5231069e8780347741f4b152504fb Mon Sep 17 00:00:00 2001 From: Luna Date: Fri, 9 Jun 2023 23:53:04 -0300 Subject: [PATCH] add fight mode between tagger models --- main.py | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 102 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 624143e..922421c 100644 --- a/main.py +++ b/main.py @@ -1,8 +1,10 @@ import json import aiohttp +import aiofiles import sys import asyncio import aiosqlite +import base64 import logging from collections import defaultdict from pathlib import Path @@ -24,6 +26,51 @@ DEFAULTS = [ ] +@dataclass +class Interrogator: + model_id: str + address: str + + +class DDInterrogator(Interrogator): + async def interrogate(self, session, path): + async with session.post( + f"{self.address}/", + params={ + "threshold": "0.7", + }, + headers={"Authorization": "Bearer sex"}, + data={"file": path.open("rb")}, + ) as resp: + assert resp.status == 200 + tags = await resp.json() + upstream_tags = [tag.replace(" ", "_") for tag in tags] + return " ".join(upstream_tags) + + +class SDInterrogator(Interrogator): + async def interrogate(self, session, path): + async with aiofiles.open(path, "rb") as fd: + as_base64 = base64.b64encode(await fd.read()).decode("utf-8") + + async with session.post( + f"{self.address}/tagger/v1/interrogate", + json={ + "model": self.model_id, + "threshold": 0.7, + "image": as_base64, + }, + ) as resp: + log.info("got %d", resp.status) + assert resp.status == 200 + data = await resp.json() + tags_with_scores = data["caption"] + tags = list(tags_with_scores.keys()) + + upstream_tags = [tag.replace(" ", "_") for tag in tags] + return " ".join(upstream_tags) + + @dataclass class Config: sd_webui_address: str @@ -31,6 +78,13 @@ class Config: dd_model_name: str sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS)) + @property + def all_available_models(self) -> List[Any]: + return [ + SDInterrogator(sd_interrogator, self.sd_webui_address) + for sd_interrogator in self.sd_webui_models + ] + [DDInterrogator(self.dd_model_name, self.dd_address)] + @dataclass class Context: @@ -137,14 +191,58 @@ async def insert_post(ctx, post): await ctx.db.commit() +async def insert_interrogated_result( + ctx, interrogator: Interrogator, md5: str, tag_string: str +): + await ctx.db.execute_insert( + "insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)", + (md5, interrogator.model_id, tag_string), + ) + await ctx.db.commit() + + async def fight(ctx): - pass + interrogators = ctx.config.all_available_models + + all_rows = await ctx.db.execute_fetchall("select md5 from posts") + all_hashes = set(r[0] for r in all_rows) + + for interrogator in interrogators: + log.info("processing for %r", interrogator) + # calculate set of images we didn't interrogate yet + interrogated_rows = await ctx.db.execute_fetchall( + "select md5 from interrogated_posts where model_name = ?", + (interrogator.model_id,), + ) + interrogated_hashes = set(row[0] for row in interrogated_rows) + missing_hashes = all_hashes - interrogated_hashes + + log.info("missing %d hashes", len(missing_hashes)) + + for missing_hash in missing_hashes: + log.info("interrogating %r", missing_hash) + post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*")) + tag_string = await interrogator.interrogate(ctx.session, post_filepath) + log.info("got %r", tag_string) + await insert_interrogated_result( + ctx, interrogator, missing_hash, tag_string + ) async def realmain(ctx): await ctx.db.executescript( """ - create table if not exists posts (md5 text primary key, filepath text, data text); + create table if not exists posts ( + md5 text primary key, + filepath text, + data text + ); + create table if not exists interrogated_posts ( + md5 text, + model_name text not null, + output_tag_string text not null, + primary key (md5, model_name) + ); """ ) @@ -157,6 +255,8 @@ async def realmain(ctx): await download_images(ctx) elif mode == "fight": await fight(ctx) + elif mode == "scores": + await scores(ctx) else: raise AssertionError(f"invalid mode {mode}")