add ControlInterrogator

because validating that scoring gives 1.0 for the real data is important
This commit is contained in:
Luna 2023-06-10 03:55:45 -03:00
parent f53ae5779a
commit 5a51c67003
1 changed files with 20 additions and 9 deletions

29
main.py
View File

@ -67,8 +67,8 @@ class DDInterrogator(Interrogator):
new_lst.append(original_danbooru_tag)
return new_lst
async def interrogate(self, session, path):
async with session.post(
async def interrogate(self, ctx, path):
async with ctx.session.post(
f"{self.address}/",
params={
"threshold": "0.7",
@ -83,11 +83,11 @@ class DDInterrogator(Interrogator):
class SDInterrogator(Interrogator):
async def interrogate(self, session, path):
async def interrogate(self, ctx, path):
async with aiofiles.open(path, "rb") as fd:
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
async with session.post(
async with ctx.session.post(
f"{self.address}/tagger/v1/interrogate",
json={
"model": self.model_id,
@ -105,6 +105,13 @@ class SDInterrogator(Interrogator):
return " ".join(upstream_tags)
class ControlInterrogator(Interrogator):
async def interrogate(self, ctx, path):
md5_hash = Path(path).stem
post = await fetch_post(ctx, md5_hash)
return post["tag_string"]
@dataclass
class Config:
sd_webui_address: str
@ -114,10 +121,14 @@ class Config:
@property
def all_available_models(self) -> List[Any]:
return [
SDInterrogator(sd_interrogator, self.sd_webui_address)
for sd_interrogator in self.sd_webui_models
] + [DDInterrogator(self.dd_model_name, self.dd_address)]
return (
[
SDInterrogator(sd_interrogator, self.sd_webui_address)
for sd_interrogator in self.sd_webui_models
]
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
+ [ControlInterrogator("control", None)]
)
@dataclass
@ -291,7 +302,7 @@ async def fight(ctx):
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
start_ts = time.monotonic()
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
tag_string = await interrogator.interrogate(ctx, post_filepath)
end_ts = time.monotonic()
time_taken = round(end_ts - start_ts, 10)