add camie-tagger and joytag
This commit is contained in:
parent
990b69c602
commit
90695c0310
2 changed files with 55 additions and 24 deletions
65
main.py
65
main.py
|
@ -48,6 +48,7 @@ class InterrogatorPost:
|
|||
class Interrogator:
|
||||
model_id: str
|
||||
address: str
|
||||
threshold: float = 0.7
|
||||
|
||||
def _process(self, lst):
|
||||
return lst
|
||||
|
@ -96,15 +97,16 @@ class SDInterrogator(Interrogator):
|
|||
async with aiofiles.open(path, "rb") as fd:
|
||||
as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
|
||||
|
||||
url = f"{self.address}/tagger/v1/interrogate"
|
||||
async with ctx.session.post(
|
||||
f"{self.address}/tagger/v1/interrogate",
|
||||
url,
|
||||
json={
|
||||
"model": self.model_id,
|
||||
"threshold": 0.7,
|
||||
"threshold": self.threshold,
|
||||
"image": as_base64,
|
||||
},
|
||||
) as resp:
|
||||
log.info("got %d", resp.status)
|
||||
log.info("%s got %d from %s", path, resp.status, url)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
tags = []
|
||||
|
@ -136,6 +138,8 @@ class Config:
|
|||
dd_address: str
|
||||
dd_model_name: str
|
||||
sd_webui_extras: Dict[str, str]
|
||||
camie_address: str
|
||||
joytag_address: str
|
||||
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
|
||||
|
||||
@property
|
||||
|
@ -149,7 +153,11 @@ class Config:
|
|||
SDInterrogator(sd_interrogator, url)
|
||||
for sd_interrogator, url in self.sd_webui_extras.items()
|
||||
]
|
||||
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||
+ [
|
||||
DDInterrogator(self.dd_model_name, self.dd_address),
|
||||
SDInterrogator("camie-tagger-v1", self.camie_address, 0.325),
|
||||
SDInterrogator("joytag-v1", self.joytag_address, 0.4),
|
||||
]
|
||||
+ [ControlInterrogator("control", None)]
|
||||
)
|
||||
|
||||
|
@ -183,7 +191,13 @@ class Danbooru(Booru):
|
|||
async def posts(self, tag_query: str, limit, page: int):
|
||||
log.info("%s: submit %r", self.title, tag_query)
|
||||
async with self.limiter:
|
||||
log.info("%s: submit upstream %r", self.title, tag_query)
|
||||
log.info(
|
||||
"%s: submit upstream query=%r limit=%r page=%r",
|
||||
self.title,
|
||||
tag_query,
|
||||
limit,
|
||||
page,
|
||||
)
|
||||
async with self.session.get(
|
||||
f"{self.base_url}/posts.json",
|
||||
params={"tags": tag_query, "limit": limit, "page": page},
|
||||
|
@ -305,6 +319,22 @@ async def insert_interrogated_result(
|
|||
await ctx.db.commit()
|
||||
|
||||
|
||||
async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
|
||||
async with semaphore:
|
||||
log.info("interrogating %r (%d/%d)", missing_hash, index, total)
|
||||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||
|
||||
start_ts = time.monotonic()
|
||||
tag_string = await interrogator.interrogate(ctx, post_filepath)
|
||||
end_ts = time.monotonic()
|
||||
time_taken = round(end_ts - start_ts, 10)
|
||||
|
||||
log.info("took %.5fsec, got %r", time_taken, tag_string)
|
||||
await insert_interrogated_result(
|
||||
ctx, interrogator, missing_hash, tag_string, time_taken
|
||||
)
|
||||
|
||||
|
||||
async def fight(ctx):
|
||||
interrogators = ctx.config.all_available_models
|
||||
|
||||
|
@ -323,21 +353,22 @@ async def fight(ctx):
|
|||
|
||||
log.info("missing %d hashes", len(missing_hashes))
|
||||
|
||||
semaphore = asyncio.Semaphore(3)
|
||||
|
||||
tasks = []
|
||||
for index, missing_hash in enumerate(missing_hashes):
|
||||
log.info(
|
||||
"interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
|
||||
task = process_hash(
|
||||
ctx,
|
||||
interrogator,
|
||||
missing_hash,
|
||||
semaphore,
|
||||
index + 1,
|
||||
len(missing_hashes),
|
||||
)
|
||||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||
tasks.append(task)
|
||||
|
||||
start_ts = time.monotonic()
|
||||
tag_string = await interrogator.interrogate(ctx, post_filepath)
|
||||
end_ts = time.monotonic()
|
||||
time_taken = round(end_ts - start_ts, 10)
|
||||
|
||||
log.info("took %.5fsec, got %r", time_taken, tag_string)
|
||||
await insert_interrogated_result(
|
||||
ctx, interrogator, missing_hash, tag_string, time_taken
|
||||
)
|
||||
# Run all tasks concurrently with semaphore limiting to 3 at a time
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
def score(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
aiosqlite==0.20.0
|
||||
aiohttp==3.10.0
|
||||
aiolimiter>1.1.0<2.0
|
||||
aiofiles==24.1.0
|
||||
plotly>5.15.0<6.0
|
||||
pandas==2.2.2
|
||||
kaleido==0.2.1
|
||||
aiosqlite
|
||||
aiohttp
|
||||
aiolimiter>1.1.0
|
||||
aiofiles
|
||||
plotly>5.15.0
|
||||
pandas
|
||||
kaleido
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue