add ControlInterrogator
because validating that scoring gives 1.0 for the real data is important
This commit is contained in:
parent
f53ae5779a
commit
5a51c67003
1 changed files with 20 additions and 9 deletions
29
main.py
29
main.py
|
@ -67,8 +67,8 @@ class DDInterrogator(Interrogator):
|
|||
new_lst.append(original_danbooru_tag)
|
||||
return new_lst
|
||||
|
||||
async def interrogate(self, session, path):
|
||||
async with session.post(
|
||||
async def interrogate(self, ctx, path):
|
||||
async with ctx.session.post(
|
||||
f"{self.address}/",
|
||||
params={
|
||||
"threshold": "0.7",
|
||||
|
@ -83,11 +83,11 @@ class DDInterrogator(Interrogator):
|
|||
|
||||
|
||||
class SDInterrogator(Interrogator):
|
||||
async def interrogate(self, session, path):
|
||||
async def interrogate(self, ctx, path):
|
||||
async with aiofiles.open(path, "rb") as fd:
|
||||
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
|
||||
|
||||
async with session.post(
|
||||
async with ctx.session.post(
|
||||
f"{self.address}/tagger/v1/interrogate",
|
||||
json={
|
||||
"model": self.model_id,
|
||||
|
@ -105,6 +105,13 @@ class SDInterrogator(Interrogator):
|
|||
return " ".join(upstream_tags)
|
||||
|
||||
|
||||
class ControlInterrogator(Interrogator):
|
||||
async def interrogate(self, ctx, path):
|
||||
md5_hash = Path(path).stem
|
||||
post = await fetch_post(ctx, md5_hash)
|
||||
return post["tag_string"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
sd_webui_address: str
|
||||
|
@ -114,10 +121,14 @@ class Config:
|
|||
|
||||
@property
|
||||
def all_available_models(self) -> List[Any]:
|
||||
return [
|
||||
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
||||
for sd_interrogator in self.sd_webui_models
|
||||
] + [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||
return (
|
||||
[
|
||||
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
||||
for sd_interrogator in self.sd_webui_models
|
||||
]
|
||||
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||
+ [ControlInterrogator("control", None)]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -291,7 +302,7 @@ async def fight(ctx):
|
|||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||
|
||||
start_ts = time.monotonic()
|
||||
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
|
||||
tag_string = await interrogator.interrogate(ctx, post_filepath)
|
||||
end_ts = time.monotonic()
|
||||
time_taken = round(end_ts - start_ts, 10)
|
||||
|
||||
|
|
Loading…
Reference in a new issue