Compare commits

...

5 commits

Author SHA1 Message Date
990b69c602 update deps 2024-08-04 17:37:40 -03:00
634848bd65 support old and new wd tagger api 2024-08-04 17:37:32 -03:00
a7258a1140 let extra interrogator models be configured in 2024-08-04 17:37:17 -03:00
967813b429 remove wd14-swinv2-v1, it brokey 2024-08-04 17:37:02 -03:00
8449d73675 fix? 2024-08-04 16:12:34 -03:00
2 changed files with 25 additions and 10 deletions

25
main.py
View file

@ -17,7 +17,7 @@ import plotly.io as pio
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set, Tuple
from typing import Any, Optional, List, Set, Tuple, Dict
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@ -31,7 +31,7 @@ DEFAULTS = [
"wd14-convnext.v1",
"wd14-convnext.v2",
"wd14-convnextv2.v1",
"wd14-swinv2-v1",
# "wd14-swinv2-v1",
"wd-v1-4-moat-tagger.v2",
"mld-caformer.dec-5-97527",
# broken model: "mld-tresnetd.6-30000",
@ -57,7 +57,8 @@ class Interrogator:
"select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
if not rows:
raise AssertionError("run fight mode first, there are missing posts..")
tag_string, time_taken = rows[0][0], rows[0][1]
return InterrogatorPost(self._process(tag_string.split()), time_taken)
@ -106,8 +107,17 @@ class SDInterrogator(Interrogator):
log.info("got %d", resp.status)
assert resp.status == 200
data = await resp.json()
tags_with_scores = data["caption"]
tags = list(tags_with_scores.keys())
tags = []
for maybe_tag, maybe_weight in data["caption"].items():
if isinstance(maybe_weight, float):
tags.append(maybe_tag)
elif isinstance(maybe_weight, dict):
for tag, weight in maybe_weight.items():
assert isinstance(weight, float)
tags.append(tag)
else:
raise AssertionError(f"invalid weight: {maybe_weight!s}")
upstream_tags = [tag.replace(" ", "_") for tag in tags]
return " ".join(upstream_tags)
@ -125,6 +135,7 @@ class Config:
sd_webui_address: str
dd_address: str
dd_model_name: str
sd_webui_extras: Dict[str, str]
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
@property
@ -134,6 +145,10 @@ class Config:
SDInterrogator(sd_interrogator, self.sd_webui_address)
for sd_interrogator in self.sd_webui_models
]
+ [
SDInterrogator(sd_interrogator, url)
for sd_interrogator, url in self.sd_webui_extras.items()
]
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
+ [ControlInterrogator("control", None)]
)

View file

@ -1,7 +1,7 @@
aiosqlite==0.19.0
aiohttp==3.8.4
aiosqlite==0.20.0
aiohttp==3.10.0
aiolimiter>1.1.0<2.0
aiofiles==23.1.0
aiofiles==24.1.0
plotly>5.15.0<6.0
pandas==2.0.2
kaleido==0.2.1
pandas==2.2.2
kaleido==0.2.1