diff --git a/main.py b/main.py index 69be1ad..7bbf35e 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ import plotly.io as pio from collections import defaultdict, Counter from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional, List, Set, Tuple +from typing import Any, Optional, List, Set, Tuple, Dict from aiolimiter import AsyncLimiter from dataclasses import dataclass, field @@ -31,7 +31,7 @@ DEFAULTS = [ "wd14-convnext.v1", "wd14-convnext.v2", "wd14-convnextv2.v1", - "wd14-swinv2-v1", + # "wd14-swinv2-v1", "wd-v1-4-moat-tagger.v2", "mld-caformer.dec-5-97527", # broken model: "mld-tresnetd.6-30000", @@ -57,7 +57,8 @@ class Interrogator: "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?", (md5_hash, self.model_id), ) - assert len(rows) == 1 # run 'fight' mode, there's missing posts + if not rows: + raise AssertionError("run fight mode first, there are missing posts..") tag_string, time_taken = rows[0][0], rows[0][1] return InterrogatorPost(self._process(tag_string.split()), time_taken) @@ -106,8 +107,17 @@ class SDInterrogator(Interrogator): log.info("got %d", resp.status) assert resp.status == 200 data = await resp.json() - tags_with_scores = data["caption"] - tags = list(tags_with_scores.keys()) + tags = [] + + for maybe_tag, maybe_weight in data["caption"].items(): + if isinstance(maybe_weight, float): + tags.append(maybe_tag) + elif isinstance(maybe_weight, dict): + for tag, weight in maybe_weight.items(): + assert isinstance(weight, float) + tags.append(tag) + else: + raise AssertionError(f"invalid weight: {maybe_weight!s}") upstream_tags = [tag.replace(" ", "_") for tag in tags] return " ".join(upstream_tags) @@ -125,6 +135,7 @@ class Config: sd_webui_address: str dd_address: str dd_model_name: str + sd_webui_extras: Dict[str, str] sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS)) @property @@ -134,6 +145,10 @@ class Config: SDInterrogator(sd_interrogator, self.sd_webui_address) for sd_interrogator in self.sd_webui_models ] + + [ + SDInterrogator(sd_interrogator, url) + for sd_interrogator, url in self.sd_webui_extras.items() + ] + [DDInterrogator(self.dd_model_name, self.dd_address)] + [ControlInterrogator("control", None)] ) diff --git a/requirements.txt b/requirements.txt index 6bb4e45..cb2f5b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -aiosqlite==0.19.0 -aiohttp==3.8.4 +aiosqlite==0.20.0 +aiohttp==3.10.0 aiolimiter>1.1.0<2.0 -aiofiles==23.1.0 +aiofiles==24.1.0 plotly>5.15.0<6.0 -pandas==2.0.2 -kaleido==0.2.1 \ No newline at end of file +pandas==2.2.2 +kaleido==0.2.1