add fight mode between tagger models

This commit is contained in:
Luna 2023-06-09 23:53:04 -03:00
parent 6d113e77b8
commit dbf458eed7

104
main.py
View file

@ -1,8 +1,10 @@
import json
import aiohttp
import aiofiles
import sys
import asyncio
import aiosqlite
import base64
import logging
from collections import defaultdict
from pathlib import Path
@ -24,6 +26,51 @@ DEFAULTS = [
]
@dataclass
class Interrogator:
model_id: str
address: str
class DDInterrogator(Interrogator):
async def interrogate(self, session, path):
async with session.post(
f"{self.address}/",
params={
"threshold": "0.7",
},
headers={"Authorization": "Bearer sex"},
data={"file": path.open("rb")},
) as resp:
assert resp.status == 200
tags = await resp.json()
upstream_tags = [tag.replace(" ", "_") for tag in tags]
return " ".join(upstream_tags)
class SDInterrogator(Interrogator):
async def interrogate(self, session, path):
async with aiofiles.open(path, "rb") as fd:
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
async with session.post(
f"{self.address}/tagger/v1/interrogate",
json={
"model": self.model_id,
"threshold": 0.7,
"image": as_base64,
},
) as resp:
log.info("got %d", resp.status)
assert resp.status == 200
data = await resp.json()
tags_with_scores = data["caption"]
tags = list(tags_with_scores.keys())
upstream_tags = [tag.replace(" ", "_") for tag in tags]
return " ".join(upstream_tags)
@dataclass
class Config:
sd_webui_address: str
@ -31,6 +78,13 @@ class Config:
dd_model_name: str
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
@property
def all_available_models(self) -> List[Any]:
return [
SDInterrogator(sd_interrogator, self.sd_webui_address)
for sd_interrogator in self.sd_webui_models
] + [DDInterrogator(self.dd_model_name, self.dd_address)]
@dataclass
class Context:
@ -137,14 +191,58 @@ async def insert_post(ctx, post):
await ctx.db.commit()
async def insert_interrogated_result(
ctx, interrogator: Interrogator, md5: str, tag_string: str
):
await ctx.db.execute_insert(
"insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)",
(md5, interrogator.model_id, tag_string),
)
await ctx.db.commit()
async def fight(ctx):
pass
interrogators = ctx.config.all_available_models
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
all_hashes = set(r[0] for r in all_rows)
for interrogator in interrogators:
log.info("processing for %r", interrogator)
# calculate set of images we didn't interrogate yet
interrogated_rows = await ctx.db.execute_fetchall(
"select md5 from interrogated_posts where model_name = ?",
(interrogator.model_id,),
)
interrogated_hashes = set(row[0] for row in interrogated_rows)
missing_hashes = all_hashes - interrogated_hashes
log.info("missing %d hashes", len(missing_hashes))
for missing_hash in missing_hashes:
log.info("interrogating %r", missing_hash)
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
log.info("got %r", tag_string)
await insert_interrogated_result(
ctx, interrogator, missing_hash, tag_string
)
async def realmain(ctx):
await ctx.db.executescript(
"""
create table if not exists posts (md5 text primary key, filepath text, data text);
create table if not exists posts (
md5 text primary key,
filepath text,
data text
);
create table if not exists interrogated_posts (
md5 text,
model_name text not null,
output_tag_string text not null,
primary key (md5, model_name)
);
"""
)
@ -157,6 +255,8 @@ async def realmain(ctx):
await download_images(ctx)
elif mode == "fight":
await fight(ctx)
elif mode == "scores":
await scores(ctx)
else:
raise AssertionError(f"invalid mode {mode}")