add camie-tagger and joytag

This commit is contained in:
Luna 2025-04-11 22:43:00 -03:00
parent 990b69c602
commit 90695c0310
2 changed files with 55 additions and 24 deletions

65
main.py
View file

@ -48,6 +48,7 @@ class InterrogatorPost:
class Interrogator:
model_id: str
address: str
threshold: float = 0.7
def _process(self, lst):
return lst
@ -96,15 +97,16 @@ class SDInterrogator(Interrogator):
async with aiofiles.open(path, "rb") as fd:
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
url = f"{self.address}/tagger/v1/interrogate"
async with ctx.session.post(
f"{self.address}/tagger/v1/interrogate",
url,
json={
"model": self.model_id,
"threshold": 0.7,
"threshold": self.threshold,
"image": as_base64,
},
) as resp:
log.info("got %d", resp.status)
log.info("%s got %d from %s", path, resp.status, url)
assert resp.status == 200
data = await resp.json()
tags = []
@ -136,6 +138,8 @@ class Config:
dd_address: str
dd_model_name: str
sd_webui_extras: Dict[str, str]
camie_address: str
joytag_address: str
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
@property
@ -149,7 +153,11 @@ class Config:
SDInterrogator(sd_interrogator, url)
for sd_interrogator, url in self.sd_webui_extras.items()
]
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
+ [
DDInterrogator(self.dd_model_name, self.dd_address),
SDInterrogator("camie-tagger-v1", self.camie_address, 0.325),
SDInterrogator("joytag-v1", self.joytag_address, 0.4),
]
+ [ControlInterrogator("control", None)]
)
@ -183,7 +191,13 @@ class Danbooru(Booru):
async def posts(self, tag_query: str, limit, page: int):
log.info("%s: submit %r", self.title, tag_query)
async with self.limiter:
log.info("%s: submit upstream %r", self.title, tag_query)
log.info(
"%s: submit upstream query=%r limit=%r page=%r",
self.title,
tag_query,
limit,
page,
)
async with self.session.get(
f"{self.base_url}/posts.json",
params={"tags": tag_query, "limit": limit, "page": page},
@ -305,6 +319,22 @@ async def insert_interrogated_result(
await ctx.db.commit()
async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
async with semaphore:
log.info("interrogating %r (%d/%d)", missing_hash, index, total)
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
start_ts = time.monotonic()
tag_string = await interrogator.interrogate(ctx, post_filepath)
end_ts = time.monotonic()
time_taken = round(end_ts - start_ts, 10)
log.info("took %.5fsec, got %r", time_taken, tag_string)
await insert_interrogated_result(
ctx, interrogator, missing_hash, tag_string, time_taken
)
async def fight(ctx):
interrogators = ctx.config.all_available_models
@ -323,21 +353,22 @@ async def fight(ctx):
log.info("missing %d hashes", len(missing_hashes))
semaphore = asyncio.Semaphore(3)
tasks = []
for index, missing_hash in enumerate(missing_hashes):
log.info(
"interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
task = process_hash(
ctx,
interrogator,
missing_hash,
semaphore,
index + 1,
len(missing_hashes),
)
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
tasks.append(task)
start_ts = time.monotonic()
tag_string = await interrogator.interrogate(ctx, post_filepath)
end_ts = time.monotonic()
time_taken = round(end_ts - start_ts, 10)
log.info("took %.5fsec, got %r", time_taken, tag_string)
await insert_interrogated_result(
ctx, interrogator, missing_hash, tag_string, time_taken
)
# Run all tasks concurrently with semaphore limiting to 3 at a time
await asyncio.gather(*tasks)
def score(

View file

@ -1,7 +1,7 @@
aiosqlite==0.20.0
aiohttp==3.10.0
aiolimiter>1.1.0<2.0
aiofiles==24.1.0
plotly>5.15.0<6.0
pandas==2.2.2
kaleido==0.2.1
aiosqlite
aiohttp
aiolimiter>1.1.0
aiofiles
plotly>5.15.0
pandas
kaleido