add ControlInterrogator
because validating that scoring gives 1.0 for the real data is important
This commit is contained in:
parent
f53ae5779a
commit
5a51c67003
29
main.py
29
main.py
|
@ -67,8 +67,8 @@ class DDInterrogator(Interrogator):
|
||||||
new_lst.append(original_danbooru_tag)
|
new_lst.append(original_danbooru_tag)
|
||||||
return new_lst
|
return new_lst
|
||||||
|
|
||||||
async def interrogate(self, session, path):
|
async def interrogate(self, ctx, path):
|
||||||
async with session.post(
|
async with ctx.session.post(
|
||||||
f"{self.address}/",
|
f"{self.address}/",
|
||||||
params={
|
params={
|
||||||
"threshold": "0.7",
|
"threshold": "0.7",
|
||||||
|
@ -83,11 +83,11 @@ class DDInterrogator(Interrogator):
|
||||||
|
|
||||||
|
|
||||||
class SDInterrogator(Interrogator):
|
class SDInterrogator(Interrogator):
|
||||||
async def interrogate(self, session, path):
|
async def interrogate(self, ctx, path):
|
||||||
async with aiofiles.open(path, "rb") as fd:
|
async with aiofiles.open(path, "rb") as fd:
|
||||||
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
|
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
|
||||||
|
|
||||||
async with session.post(
|
async with ctx.session.post(
|
||||||
f"{self.address}/tagger/v1/interrogate",
|
f"{self.address}/tagger/v1/interrogate",
|
||||||
json={
|
json={
|
||||||
"model": self.model_id,
|
"model": self.model_id,
|
||||||
|
@ -105,6 +105,13 @@ class SDInterrogator(Interrogator):
|
||||||
return " ".join(upstream_tags)
|
return " ".join(upstream_tags)
|
||||||
|
|
||||||
|
|
||||||
|
class ControlInterrogator(Interrogator):
|
||||||
|
async def interrogate(self, ctx, path):
|
||||||
|
md5_hash = Path(path).stem
|
||||||
|
post = await fetch_post(ctx, md5_hash)
|
||||||
|
return post["tag_string"]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Config:
|
class Config:
|
||||||
sd_webui_address: str
|
sd_webui_address: str
|
||||||
|
@ -114,10 +121,14 @@ class Config:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def all_available_models(self) -> List[Any]:
|
def all_available_models(self) -> List[Any]:
|
||||||
return [
|
return (
|
||||||
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
[
|
||||||
for sd_interrogator in self.sd_webui_models
|
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
||||||
] + [DDInterrogator(self.dd_model_name, self.dd_address)]
|
for sd_interrogator in self.sd_webui_models
|
||||||
|
]
|
||||||
|
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||||
|
+ [ControlInterrogator("control", None)]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -291,7 +302,7 @@ async def fight(ctx):
|
||||||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||||
|
|
||||||
start_ts = time.monotonic()
|
start_ts = time.monotonic()
|
||||||
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
|
tag_string = await interrogator.interrogate(ctx, post_filepath)
|
||||||
end_ts = time.monotonic()
|
end_ts = time.monotonic()
|
||||||
time_taken = round(end_ts - start_ts, 10)
|
time_taken = round(end_ts - start_ts, 10)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue