Compare commits

..

No commits in common. "990b69c602f632ebbc4a586ab82220faef27f2c9" and "4b8202f493b3c3f1dc828b212190ebbab934f7a3" have entirely different histories.

2 changed files with 10 additions and 25 deletions

25
main.py
View file

@ -17,7 +17,7 @@ import plotly.io as pio
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set, Tuple, Dict
from typing import Any, Optional, List, Set, Tuple
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@ -31,7 +31,7 @@ DEFAULTS = [
"wd14-convnext.v1",
"wd14-convnext.v2",
"wd14-convnextv2.v1",
# "wd14-swinv2-v1",
"wd14-swinv2-v1",
"wd-v1-4-moat-tagger.v2",
"mld-caformer.dec-5-97527",
# broken model: "mld-tresnetd.6-30000",
@ -57,8 +57,7 @@ class Interrogator:
"select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id),
)
if not rows:
raise AssertionError("run fight mode first, there are missing posts..")
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string, time_taken = rows[0][0], rows[0][1]
return InterrogatorPost(self._process(tag_string.split()), time_taken)
@ -107,17 +106,8 @@ class SDInterrogator(Interrogator):
log.info("got %d", resp.status)
assert resp.status == 200
data = await resp.json()
tags = []
for maybe_tag, maybe_weight in data["caption"].items():
if isinstance(maybe_weight, float):
tags.append(maybe_tag)
elif isinstance(maybe_weight, dict):
for tag, weight in maybe_weight.items():
assert isinstance(weight, float)
tags.append(tag)
else:
raise AssertionError(f"invalid weight: {maybe_weight!s}")
tags_with_scores = data["caption"]
tags = list(tags_with_scores.keys())
upstream_tags = [tag.replace(" ", "_") for tag in tags]
return " ".join(upstream_tags)
@ -135,7 +125,6 @@ class Config:
sd_webui_address: str
dd_address: str
dd_model_name: str
sd_webui_extras: Dict[str, str]
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
@property
@ -145,10 +134,6 @@ class Config:
SDInterrogator(sd_interrogator, self.sd_webui_address)
for sd_interrogator in self.sd_webui_models
]
+ [
SDInterrogator(sd_interrogator, url)
for sd_interrogator, url in self.sd_webui_extras.items()
]
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
+ [ControlInterrogator("control", None)]
)

View file

@ -1,7 +1,7 @@
aiosqlite==0.20.0
aiohttp==3.10.0
aiosqlite==0.19.0
aiohttp==3.8.4
aiolimiter>1.1.0<2.0
aiofiles==24.1.0
aiofiles==23.1.0
plotly>5.15.0<6.0
pandas==2.2.2
kaleido==0.2.1
pandas==2.0.2
kaleido==0.2.1