Compare commits

..

3 commits

Author SHA1 Message Date
099c7f8fc9 add some basic analysis about final scores 2023-06-10 01:30:48 -03:00
f3ec3ab6d3 fix extension check 2023-06-10 01:30:40 -03:00
919bd4017e fix deepdanbooru's rating tags 2023-06-10 01:30:26 -03:00

105
main.py
View file

@ -1,3 +1,4 @@
import os
import json
import aiohttp
import pprint
@ -8,10 +9,10 @@ import asyncio
import aiosqlite
import base64
import logging
from collections import defaultdict
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set
from typing import Any, Optional, List, Set, Tuple
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@ -33,8 +34,32 @@ class Interrogator:
model_id: str
address: str
def _process(self, lst):
return lst
async def fetch_tags(self, ctx, md5_hash):
    """Return this model's stored tags for the post identified by ``md5_hash``.

    Reads the ``output_tag_string`` saved in ``interrogated_posts`` for this
    interrogator's ``model_id``, splits it on whitespace and passes the
    resulting list through ``self._process``.

    Args:
        ctx: Object exposing ``db.execute_fetchall``, awaited with the SQL
            text and a parameter tuple.
        md5_hash: Value matched against the ``md5`` column.

    Returns:
        Whatever ``self._process`` returns for the split tag list.

    Raises:
        AssertionError: If the lookup does not yield exactly one row.
    """
    rows = await ctx.db.execute_fetchall(
        "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
        (md5_hash, self.model_id),
    )
    # Raise explicitly rather than via a bare ``assert`` so the check is not
    # stripped under ``python -O``; the exception type stays the same.
    if len(rows) != 1:
        raise AssertionError(
            f"expected exactly one interrogated_posts row for md5={md5_hash!r} "
            f"and model={self.model_id!r}, got {len(rows)}; "
            "run 'fight' mode, there's missing posts"
        )
    tag_string = rows[0][0]
    return self._process(tag_string.split())
class DDInterrogator(Interrogator):
def _process(self, lst):
new_lst = []
for tag in lst:
if tag.startswith("rating:"):
original_danbooru_tag = tag.split(":")[1]
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
continue
new_lst.append(original_danbooru_tag)
return new_lst
async def interrogate(self, session, path):
async with session.post(
f"{self.address}/",
@ -163,6 +188,7 @@ async def download_images(ctx):
log.info("processing post %r", md5)
existing_post = await fetch_post(ctx, md5)
if existing_post:
log.info("already exists %r", md5)
continue
# download the post
@ -170,10 +196,10 @@ async def download_images(ctx):
post_url_path = Path(urlparse(post_file_url).path)
good_extension = False
for extension in ALLOWED_EXTENSIONS:
if extension in post_url_path.stem:
if extension in post_url_path.suffix:
good_extension = True
if not good_extension:
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
continue
post_filename = post_url_path.name
post_filepath = DOWNLOADS / post_filename
@ -263,12 +289,19 @@ async def fight(ctx):
)
def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score an interrogator's tag set against the Danbooru tag set.

    The score is ``(hits - misses) / len(danbooru_tags)`` rounded to 10
    decimal places, where *hits* are interrogator tags also present in
    ``danbooru_tags`` and *misses* are interrogator tags that are not.

    Args:
        danbooru_tags: Reference tags for the post.
        interrogator_tags: Tags produced by the model for the same post.

    Returns:
        ``(score, misses)``: the rounded score and the set of interrogator
        tags absent from ``danbooru_tags``.

    Raises:
        decimal.DecimalException: If ``danbooru_tags`` is empty, because the
            division by zero is trapped by the default decimal context.
    """
    tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    net_correct = len(tags_in_danbooru) - len(tags_not_in_danbooru)
    ratio = decimal.Decimal(net_correct) / decimal.Decimal(len(danbooru_tags))
    return round(ratio, 10), tags_not_in_danbooru
async def scores(ctx):
@ -277,28 +310,36 @@ async def scores(ctx):
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
all_hashes = set(r[0] for r in all_rows)
absolute_scores = defaultdict(decimal.Decimal)
# absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict)
incorrect_tags_counters = defaultdict(Counter)
for md5_hash in all_hashes:
log.info("processing for %r", md5_hash)
post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators:
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, interrogator.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
interrogator_tags = set(tag_string.split())
absolute_scores[interrogator.model_id] += score(
danbooru_tags, interrogator_tags
)
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1
model_scores[interrogator.model_id][md5_hash] = {
"score": tagging_score,
"incorrect_tags": incorrect_tags,
}
summed_scores = {
model_id: sum(d["score"] for d in post_scores.values())
for model_id, post_scores in model_scores.items()
}
normalized_scores = {
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
model: summed_scores[model] / len(all_hashes) for model in summed_scores
}
print("scores are [worst...best]")
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
@ -306,6 +347,30 @@ async def scores(ctx):
):
print(model, normalized_scores[model])
# Per-model detail: the four lowest and four highest per-post scores,
# followed by the tags this model most often produced that were not in
# the Danbooru tag set.
post_scores = model_scores[model]
ranked_hashes = sorted(
    post_scores,
    key=lambda md5_hash: post_scores[md5_hash]["score"],
)
debug_enabled = os.getenv("DEBUG", "0") == "1"

print("[", end="")
for bad_md5_hash in ranked_hashes[:4]:
    data = post_scores[bad_md5_hash]
    if debug_enabled:
        # Label the line with the post actually being reported. The
        # earlier code printed `md5_hash`, a leftover from the scoring
        # loop, which named an unrelated post.
        print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
    else:
        print(data["score"], end=",")
print("...", end="")
# Walking the ascending ranking backwards yields the best scores first;
# only score values are printed, so tie order cannot change the output.
for good_md5_hash in ranked_hashes[::-1][:4]:
    data = post_scores[good_md5_hash]
    print(data["score"], end=",")
print("]")

print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
async def realmain(ctx):
await ctx.db.executescript(