diff --git a/main.py b/main.py index 7bbf35e..69be1ad 100644 --- a/main.py +++ b/main.py @@ -17,7 +17,7 @@ import plotly.io as pio from collections import defaultdict, Counter from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional, List, Set, Tuple, Dict +from typing import Any, Optional, List, Set, Tuple from aiolimiter import AsyncLimiter from dataclasses import dataclass, field @@ -31,7 +31,7 @@ DEFAULTS = [ "wd14-convnext.v1", "wd14-convnext.v2", "wd14-convnextv2.v1", - # "wd14-swinv2-v1", + "wd14-swinv2-v1", "wd-v1-4-moat-tagger.v2", "mld-caformer.dec-5-97527", # broken model: "mld-tresnetd.6-30000", @@ -57,8 +57,7 @@ class Interrogator: "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?", (md5_hash, self.model_id), ) - if not rows: - raise AssertionError("run fight mode first, there are missing posts..") + assert len(rows) == 1 # run 'fight' mode, there's missing posts tag_string, time_taken = rows[0][0], rows[0][1] return InterrogatorPost(self._process(tag_string.split()), time_taken) @@ -107,17 +106,8 @@ class SDInterrogator(Interrogator): log.info("got %d", resp.status) assert resp.status == 200 data = await resp.json() - tags = [] - - for maybe_tag, maybe_weight in data["caption"].items(): - if isinstance(maybe_weight, float): - tags.append(maybe_tag) - elif isinstance(maybe_weight, dict): - for tag, weight in maybe_weight.items(): - assert isinstance(weight, float) - tags.append(tag) - else: - raise AssertionError(f"invalid weight: {maybe_weight!s}") + tags_with_scores = data["caption"] + tags = list(tags_with_scores.keys()) upstream_tags = [tag.replace(" ", "_") for tag in tags] return " ".join(upstream_tags) @@ -135,7 +125,6 @@ class Config: sd_webui_address: str dd_address: str dd_model_name: str - sd_webui_extras: Dict[str, str] sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS)) @property @@ -145,10 +134,6 @@ class Config: SDInterrogator(sd_interrogator, self.sd_webui_address) for sd_interrogator in self.sd_webui_models ] - + [ - SDInterrogator(sd_interrogator, url) - for sd_interrogator, url in self.sd_webui_extras.items() - ] + [DDInterrogator(self.dd_model_name, self.dd_address)] + [ControlInterrogator("control", None)] ) diff --git a/requirements.txt b/requirements.txt index cb2f5b8..6bb4e45 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -aiosqlite==0.20.0 -aiohttp==3.10.0 +aiosqlite==0.19.0 +aiohttp==3.8.4 aiolimiter>1.1.0<2.0 -aiofiles==24.1.0 +aiofiles==23.1.0 plotly>5.15.0<6.0 -pandas==2.2.2 -kaleido==0.2.1 +pandas==2.0.2 +kaleido==0.2.1 \ No newline at end of file