Compare commits
5 commits
4b8202f493
...
990b69c602
Author | SHA1 | Date | |
---|---|---|---|
990b69c602 | |||
634848bd65 | |||
a7258a1140 | |||
967813b429 | |||
8449d73675 |
2 changed files with 25 additions and 10 deletions
25
main.py
25
main.py
|
@ -17,7 +17,7 @@ import plotly.io as pio
|
|||
from collections import defaultdict, Counter
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Optional, List, Set, Tuple
|
||||
from typing import Any, Optional, List, Set, Tuple, Dict
|
||||
from aiolimiter import AsyncLimiter
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
@ -31,7 +31,7 @@ DEFAULTS = [
|
|||
"wd14-convnext.v1",
|
||||
"wd14-convnext.v2",
|
||||
"wd14-convnextv2.v1",
|
||||
"wd14-swinv2-v1",
|
||||
# "wd14-swinv2-v1",
|
||||
"wd-v1-4-moat-tagger.v2",
|
||||
"mld-caformer.dec-5-97527",
|
||||
# broken model: "mld-tresnetd.6-30000",
|
||||
|
@ -57,7 +57,8 @@ class Interrogator:
|
|||
"select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
|
||||
(md5_hash, self.model_id),
|
||||
)
|
||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||
if not rows:
|
||||
raise AssertionError("run fight mode first, there are missing posts..")
|
||||
tag_string, time_taken = rows[0][0], rows[0][1]
|
||||
return InterrogatorPost(self._process(tag_string.split()), time_taken)
|
||||
|
||||
|
@ -106,8 +107,17 @@ class SDInterrogator(Interrogator):
|
|||
log.info("got %d", resp.status)
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
tags_with_scores = data["caption"]
|
||||
tags = list(tags_with_scores.keys())
|
||||
tags = []
|
||||
|
||||
for maybe_tag, maybe_weight in data["caption"].items():
|
||||
if isinstance(maybe_weight, float):
|
||||
tags.append(maybe_tag)
|
||||
elif isinstance(maybe_weight, dict):
|
||||
for tag, weight in maybe_weight.items():
|
||||
assert isinstance(weight, float)
|
||||
tags.append(tag)
|
||||
else:
|
||||
raise AssertionError(f"invalid weight: {maybe_weight!s}")
|
||||
|
||||
upstream_tags = [tag.replace(" ", "_") for tag in tags]
|
||||
return " ".join(upstream_tags)
|
||||
|
@ -125,6 +135,7 @@ class Config:
|
|||
sd_webui_address: str
|
||||
dd_address: str
|
||||
dd_model_name: str
|
||||
sd_webui_extras: Dict[str, str]
|
||||
sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
|
||||
|
||||
@property
|
||||
|
@ -134,6 +145,10 @@ class Config:
|
|||
SDInterrogator(sd_interrogator, self.sd_webui_address)
|
||||
for sd_interrogator in self.sd_webui_models
|
||||
]
|
||||
+ [
|
||||
SDInterrogator(sd_interrogator, url)
|
||||
for sd_interrogator, url in self.sd_webui_extras.items()
|
||||
]
|
||||
+ [DDInterrogator(self.dd_model_name, self.dd_address)]
|
||||
+ [ControlInterrogator("control", None)]
|
||||
)
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
aiosqlite==0.19.0
|
||||
aiohttp==3.8.4
|
||||
aiosqlite==0.20.0
|
||||
aiohttp==3.10.0
|
||||
aiolimiter>1.1.0<2.0
|
||||
aiofiles==23.1.0
|
||||
aiofiles==24.1.0
|
||||
plotly>5.15.0<6.0
|
||||
pandas==2.0.2
|
||||
kaleido==0.2.1
|
||||
pandas==2.2.2
|
||||
kaleido==0.2.1
|
||||
|
|
Loading…
Reference in a new issue