diff --git a/main.py b/main.py index 7dfee96..a6f0074 100644 --- a/main.py +++ b/main.py @@ -67,8 +67,8 @@ class DDInterrogator(Interrogator): new_lst.append(original_danbooru_tag) return new_lst - async def interrogate(self, session, path): - async with session.post( + async def interrogate(self, ctx, path): + async with ctx.session.post( f"{self.address}/", params={ "threshold": "0.7", @@ -83,11 +83,11 @@ class DDInterrogator(Interrogator): class SDInterrogator(Interrogator): - async def interrogate(self, session, path): + async def interrogate(self, ctx, path): async with aiofiles.open(path, "rb") as fd: as_base64 = base64.b64encode(await fd.read()).decode("utf-8") - async with session.post( + async with ctx.session.post( f"{self.address}/tagger/v1/interrogate", json={ "model": self.model_id, @@ -105,6 +105,13 @@ class SDInterrogator(Interrogator): return " ".join(upstream_tags) +class ControlInterrogator(Interrogator): + async def interrogate(self, ctx, path): + md5_hash = Path(path).stem + post = await fetch_post(ctx, md5_hash) + return post["tag_string"] + + @dataclass class Config: sd_webui_address: str @@ -114,10 +121,14 @@ class Config: @property def all_available_models(self) -> List[Any]: - return [ - SDInterrogator(sd_interrogator, self.sd_webui_address) - for sd_interrogator in self.sd_webui_models - ] + [DDInterrogator(self.dd_model_name, self.dd_address)] + return ( + [ + SDInterrogator(sd_interrogator, self.sd_webui_address) + for sd_interrogator in self.sd_webui_models + ] + + [DDInterrogator(self.dd_model_name, self.dd_address)] + + [ControlInterrogator("control", None)] + ) @dataclass @@ -291,7 +302,7 @@ async def fight(ctx): post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*")) start_ts = time.monotonic() - tag_string = await interrogator.interrogate(ctx.session, post_filepath) + tag_string = await interrogator.interrogate(ctx, post_filepath) end_ts = time.monotonic() time_taken = round(end_ts - start_ts, 10)