From 634848bd652de63ac725c286cbb5e5700c1eb73e Mon Sep 17 00:00:00 2001 From: Luna Date: Sun, 4 Aug 2024 17:37:32 -0300 Subject: [PATCH] support old and new wd tagger api --- main.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/main.py b/main.py index fda6f09..7bbf35e 100644 --- a/main.py +++ b/main.py @@ -107,12 +107,17 @@ class SDInterrogator(Interrogator): log.info("got %d", resp.status) assert resp.status == 200 data = await resp.json() - tags_with_weights = data["caption"] - tags = [] - for tag, weight in tags_with_weights.items(): - tags.append(tag) + for maybe_tag, maybe_weight in data["caption"].items(): + if isinstance(maybe_weight, float): + tags.append(maybe_tag) + elif isinstance(maybe_weight, dict): + for tag, weight in maybe_weight.items(): + assert isinstance(weight, float) + tags.append(tag) + else: + raise AssertionError(f"invalid weight: {maybe_weight!s}") upstream_tags = [tag.replace(" ", "_") for tag in tags] return " ".join(upstream_tags)