remove rating tags from scoring
This commit is contained in:
parent
2550320f9d
commit
a73a3f26ce
1 changed files with 32 additions and 21 deletions
53
main.py
53
main.py
|
@ -37,6 +37,8 @@ DEFAULTS = [
|
|||
# broken model: "mld-tresnetd.6-30000",
|
||||
]
|
||||
|
||||
RATING = ["general", "explicit", "questionable", "sensitive", "safe"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterrogatorPost:
|
||||
|
@ -70,11 +72,16 @@ class DDInterrogator(Interrogator):
|
|||
new_lst = []
|
||||
for tag in lst:
|
||||
if tag.startswith("rating:"):
|
||||
original_danbooru_tag = tag.split(":")[1]
|
||||
continue
|
||||
else:
|
||||
original_danbooru_tag = tag
|
||||
|
||||
if original_danbooru_tag == "safe":
|
||||
original_danbooru_tag = "general"
|
||||
continue
|
||||
|
||||
if original_danbooru_tag in RATING:
|
||||
continue
|
||||
|
||||
new_lst.append(original_danbooru_tag)
|
||||
return new_lst
|
||||
|
||||
|
@ -97,10 +104,10 @@ class SDInterrogator(Interrogator):
|
|||
def _process(self, lst):
|
||||
new_lst = []
|
||||
for tag in lst:
|
||||
if self._fucked_rating and tag.startswith("rating_"):
|
||||
_, rating = tag.split("_")
|
||||
# remap vocabs for fucked vocabs
|
||||
original_danbooru_tag = rating
|
||||
if tag.startswith("rating_"):
|
||||
continue
|
||||
elif tag in RATING:
|
||||
continue
|
||||
else:
|
||||
original_danbooru_tag = tag
|
||||
new_lst.append(original_danbooru_tag)
|
||||
|
@ -138,11 +145,27 @@ class SDInterrogator(Interrogator):
|
|||
return " ".join(upstream_tags)
|
||||
|
||||
|
||||
def tag_string_for(post: dict) -> str:
|
||||
return (
|
||||
post["tag_string_general"]
|
||||
+ " "
|
||||
+ post["tag_string_copyright"]
|
||||
+ " "
|
||||
+ post["tag_string_character"]
|
||||
)
|
||||
|
||||
|
||||
class ControlInterrogator(Interrogator):
|
||||
async def fetch(self, ctx, path):
|
||||
md5_hash = Path(path).stem
|
||||
post = await fetch_post(ctx, md5_hash)
|
||||
tag_string = tag_string_for(post)
|
||||
return InterrogatorPost(tag_string.split(), 0)
|
||||
|
||||
async def interrogate(self, ctx, path):
|
||||
md5_hash = Path(path).stem
|
||||
post = await fetch_post(ctx, md5_hash)
|
||||
return post["tag_string"]
|
||||
return tag_string_for(post)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -294,19 +317,7 @@ async def fetch_post(ctx, md5) -> Optional[dict]:
|
|||
return None
|
||||
assert len(rows) == 1
|
||||
post = json.loads(rows[0][0])
|
||||
post_rating = post["rating"]
|
||||
match post_rating:
|
||||
case "g":
|
||||
rating_tag = "general"
|
||||
case "s":
|
||||
rating_tag = "sensitive"
|
||||
case "q":
|
||||
rating_tag = "questionable"
|
||||
case "e":
|
||||
rating_tag = "explicit"
|
||||
case _:
|
||||
raise AssertionError("invalid post rating {post_rating!r}")
|
||||
post["tag_string"] = post["tag_string"] + " " + rating_tag
|
||||
post["tag_string"] = tag_string_for(post)
|
||||
return post
|
||||
|
||||
|
||||
|
@ -355,7 +366,7 @@ async def fight(ctx):
|
|||
all_hashes = set(r[0] for r in all_rows)
|
||||
|
||||
for interrogator in interrogators:
|
||||
log.info("processing for %r", interrogator)
|
||||
log.info("processing fight for %r", interrogator)
|
||||
# calculate set of images we didn't interrogate yet
|
||||
interrogated_rows = await ctx.db.execute_fetchall(
|
||||
"select md5 from interrogated_posts where model_name = ?",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue