add fight mode between tagger models
This commit is contained in:
parent
6d113e77b8
commit
dbf458eed7
1 changed files with 102 additions and 2 deletions
104
main.py
104
main.py
|
@ -1,8 +1,10 @@
|
|||
import json
|
||||
import aiohttp
|
||||
import aiofiles
|
||||
import sys
|
||||
import asyncio
|
||||
import aiosqlite
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
@ -24,6 +26,51 @@ DEFAULTS = [
|
|||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Interrogator:
|
||||
model_id: str
|
||||
address: str
|
||||
|
||||
|
||||
class DDInterrogator(Interrogator):
|
||||
async def interrogate(self, session, path):
|
||||
async with session.post(
|
||||
f"{self.address}/",
|
||||
params={
|
||||
"threshold": "0.7",
|
||||
},
|
||||
headers={"Authorization": "Bearer sex"},
|
||||
data={"file": path.open("rb")},
|
||||
) as resp:
|
||||
assert resp.status == 200
|
||||
tags = await resp.json()
|
||||
upstream_tags = [tag.replace(" ", "_") for tag in tags]
|
||||
return " ".join(upstream_tags)
|
||||
|
||||
|
||||
class SDInterrogator(Interrogator):
|
||||
async def interrogate(self, session, path):
|
||||
async with aiofiles.open(path, "rb") as fd:
|
||||
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
|
||||
|
||||
async with session.post(
|
||||
f"{self.address}/tagger/v1/interrogate",
|
||||
json={
|
||||
"model": self.model_id,
|
||||
"threshold": 0.7,
|
||||
"image": as_base64,
|
||||
},
|
||||
) as resp:
|
||||
log.info("got %d", resp.status)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
tags_with_scores = data["caption"]
|
||||
tags = list(tags_with_scores.keys())
|
||||
|
||||
upstream_tags = [tag.replace(" ", "_") for tag in tags]
|
||||
return " ".join(upstream_tags)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
sd_webui_address: str
|
||||
|
@ -31,6 +78,13 @@ class Config:
|
|||
dd_model_name: str
|
||||
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
|
||||
|
||||
@property
|
||||
def all_available_models(self) -> List[Any]:
|
||||
return [
|
||||
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
||||
for sd_interrogator in self.sd_webui_models
|
||||
] + [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Context:
|
||||
|
@ -137,14 +191,58 @@ async def insert_post(ctx, post):
|
|||
await ctx.db.commit()
|
||||
|
||||
|
||||
async def insert_interrogated_result(
|
||||
ctx, interrogator: Interrogator, md5: str, tag_string: str
|
||||
):
|
||||
await ctx.db.execute_insert(
|
||||
"insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)",
|
||||
(md5, interrogator.model_id, tag_string),
|
||||
)
|
||||
await ctx.db.commit()
|
||||
|
||||
|
||||
async def fight(ctx):
|
||||
pass
|
||||
interrogators = ctx.config.all_available_models
|
||||
|
||||
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
||||
all_hashes = set(r[0] for r in all_rows)
|
||||
|
||||
for interrogator in interrogators:
|
||||
log.info("processing for %r", interrogator)
|
||||
# calculate set of images we didn't interrogate yet
|
||||
interrogated_rows = await ctx.db.execute_fetchall(
|
||||
"select md5 from interrogated_posts where model_name = ?",
|
||||
(interrogator.model_id,),
|
||||
)
|
||||
interrogated_hashes = set(row[0] for row in interrogated_rows)
|
||||
missing_hashes = all_hashes - interrogated_hashes
|
||||
|
||||
log.info("missing %d hashes", len(missing_hashes))
|
||||
|
||||
for missing_hash in missing_hashes:
|
||||
log.info("interrogating %r", missing_hash)
|
||||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
|
||||
log.info("got %r", tag_string)
|
||||
await insert_interrogated_result(
|
||||
ctx, interrogator, missing_hash, tag_string
|
||||
)
|
||||
|
||||
|
||||
async def realmain(ctx):
|
||||
await ctx.db.executescript(
|
||||
"""
|
||||
create table if not exists posts (md5 text primary key, filepath text, data text);
|
||||
create table if not exists posts (
|
||||
md5 text primary key,
|
||||
filepath text,
|
||||
data text
|
||||
);
|
||||
create table if not exists interrogated_posts (
|
||||
md5 text,
|
||||
model_name text not null,
|
||||
output_tag_string text not null,
|
||||
primary key (md5, model_name)
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
|
@ -157,6 +255,8 @@ async def realmain(ctx):
|
|||
await download_images(ctx)
|
||||
elif mode == "fight":
|
||||
await fight(ctx)
|
||||
elif mode == "scores":
|
||||
await scores(ctx)
|
||||
else:
|
||||
raise AssertionError(f"invalid mode {mode}")
|
||||
|
||||
|
|
Loading…
Reference in a new issue