diff --git a/main.py b/main.py index fda6f09..7bbf35e 100644 --- a/main.py +++ b/main.py @@ -107,12 +107,17 @@ class SDInterrogator(Interrogator): log.info("got %d", resp.status) assert resp.status == 200 data = await resp.json() - tags_with_weights = data["caption"] - tags = [] - for tag, weight in tags_with_weights.items(): - tags.append(tag) + for maybe_tag, maybe_weight in data["caption"].items(): + if isinstance(maybe_weight, float): + tags.append(maybe_tag) + elif isinstance(maybe_weight, dict): + for tag, weight in maybe_weight.items(): + assert isinstance(weight, float) + tags.append(tag) + else: + raise AssertionError(f"invalid weight: {maybe_weight!s}") upstream_tags = [tag.replace(" ", "_") for tag in tags] return " ".join(upstream_tags)