remove rating tags from scoring

This commit is contained in:
Luna 2025-04-12 00:47:48 -03:00
parent 2550320f9d
commit a73a3f26ce

53
main.py
View file

@ -37,6 +37,8 @@ DEFAULTS = [
# broken model: "mld-tresnetd.6-30000",
]
RATING = ["general", "explicit", "questionable", "sensitive", "safe"]
@dataclass
class InterrogatorPost:
@ -70,11 +72,16 @@ class DDInterrogator(Interrogator):
new_lst = []
for tag in lst:
if tag.startswith("rating:"):
original_danbooru_tag = tag.split(":")[1]
continue
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
original_danbooru_tag = "general"
continue
if original_danbooru_tag in RATING:
continue
new_lst.append(original_danbooru_tag)
return new_lst
@ -97,10 +104,10 @@ class SDInterrogator(Interrogator):
def _process(self, lst):
new_lst = []
for tag in lst:
if self._fucked_rating and tag.startswith("rating_"):
_, rating = tag.split("_")
# remap vocabs for fucked vocabs
original_danbooru_tag = rating
if tag.startswith("rating_"):
continue
elif tag in RATING:
continue
else:
original_danbooru_tag = tag
new_lst.append(original_danbooru_tag)
@ -138,11 +145,27 @@ class SDInterrogator(Interrogator):
return " ".join(upstream_tags)
def tag_string_for(post: dict) -> str:
    """Join a post's general, copyright and character tags into one string.

    Reads the ``tag_string_general``, ``tag_string_copyright`` and
    ``tag_string_character`` entries of *post*, in that order, and joins
    them with single spaces. No other key of *post* is consulted.

    Categories that are empty are left out of the join, so the result never
    carries doubled or trailing spaces when a post has, for example, no
    character tags. Splitting the result on whitespace yields the same tags
    as splitting a plain concatenation would.

    Raises:
        KeyError: if any of the three ``tag_string_*`` keys is missing.
    """
    # Index all three keys up front so a missing key still raises KeyError,
    # regardless of which categories turn out to be empty.
    categories = (
        post["tag_string_general"],
        post["tag_string_copyright"],
        post["tag_string_character"],
    )
    return " ".join(part for part in categories if part)
class ControlInterrogator(Interrogator):
async def fetch(self, ctx, path):
md5_hash = Path(path).stem
post = await fetch_post(ctx, md5_hash)
tag_string = tag_string_for(post)
return InterrogatorPost(tag_string.split(), 0)
async def interrogate(self, ctx, path):
md5_hash = Path(path).stem
post = await fetch_post(ctx, md5_hash)
return post["tag_string"]
return tag_string_for(post)
@dataclass
@ -294,19 +317,7 @@ async def fetch_post(ctx, md5) -> Optional[dict]:
return None
assert len(rows) == 1
post = json.loads(rows[0][0])
post_rating = post["rating"]
match post_rating:
case "g":
rating_tag = "general"
case "s":
rating_tag = "sensitive"
case "q":
rating_tag = "questionable"
case "e":
rating_tag = "explicit"
case _:
raise AssertionError("invalid post rating {post_rating!r}")
post["tag_string"] = post["tag_string"] + " " + rating_tag
post["tag_string"] = tag_string_for(post)
return post
@ -355,7 +366,7 @@ async def fight(ctx):
all_hashes = set(r[0] for r in all_rows)
for interrogator in interrogators:
log.info("processing for %r", interrogator)
log.info("processing fight for %r", interrogator)
# calculate set of images we didn't interrogate yet
interrogated_rows = await ctx.db.execute_fetchall(
"select md5 from interrogated_posts where model_name = ?",