support old and new wd tagger api

This commit is contained in:
Luna 2024-08-04 17:37:32 -03:00
parent a7258a1140
commit 634848bd65

13
main.py
View file

@ -107,12 +107,17 @@ class SDInterrogator(Interrogator):
log.info("got %d", resp.status)
assert resp.status == 200
data = await resp.json()
tags_with_weights = data["caption"]
tags = []
for tag, weight in tags_with_weights.items():
tags.append(tag)
for maybe_tag, maybe_weight in data["caption"].items():
if isinstance(maybe_weight, float):
tags.append(maybe_tag)
elif isinstance(maybe_weight, dict):
for tag, weight in maybe_weight.items():
assert isinstance(weight, float)
tags.append(tag)
else:
raise AssertionError(f"invalid weight: {maybe_weight!s}")
upstream_tags = [tag.replace(" ", "_") for tag in tags]
return " ".join(upstream_tags)