Compare commits

...

3 Commits

Author SHA1 Message Date
Luna 099c7f8fc9 add some basic analysis about final scores 2023-06-10 01:30:48 -03:00
Luna f3ec3ab6d3 fix extension check 2023-06-10 01:30:40 -03:00
Luna 919bd4017e fix deepdanbooru's rating tags 2023-06-10 01:30:26 -03:00
1 changed files with 85 additions and 20 deletions

105
main.py
View File

@ -1,3 +1,4 @@
import os
import json import json
import aiohttp import aiohttp
import pprint import pprint
@ -8,10 +9,10 @@ import asyncio
import aiosqlite import aiosqlite
import base64 import base64
import logging import logging
from collections import defaultdict from collections import defaultdict, Counter
from pathlib import Path from pathlib import Path
from urllib.parse import urlparse from urllib.parse import urlparse
from typing import Any, Optional, List, Set from typing import Any, Optional, List, Set, Tuple
from aiolimiter import AsyncLimiter from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field from dataclasses import dataclass, field
@ -33,8 +34,32 @@ class Interrogator:
model_id: str model_id: str
address: str address: str
def _process(self, lst):
return lst
async def fetch_tags(self, ctx, md5_hash):
    """Return this model's stored tags for the post identified by ``md5_hash``.

    Reads the ``output_tag_string`` column of ``interrogated_posts`` for the
    given hash and ``self.model_id``, splits it on whitespace and passes the
    resulting list through ``self._process``.

    Args:
        ctx: object exposing ``db.execute_fetchall`` (awaitable).
        md5_hash: md5 of the post to look up.

    Returns:
        Whatever ``self._process`` returns for the split tag list.

    Raises:
        AssertionError: if the query does not return exactly one row.
    """
    rows = await ctx.db.execute_fetchall(
        "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
        (md5_hash, self.model_id),
    )
    if len(rows) != 1:
        # Raised explicitly instead of via `assert` so the check is not
        # stripped under `python -O`; the exception type is kept the same
        # for any caller that already handles AssertionError.
        raise AssertionError(
            f"expected exactly one interrogated_posts row for md5={md5_hash!r} "
            f"and model {self.model_id!r}, found {len(rows)}; "
            "run 'fight' mode, there's missing posts"
        )
    tag_string = rows[0][0]
    return self._process(tag_string.split())
class DDInterrogator(Interrogator): class DDInterrogator(Interrogator):
def _process(self, lst):
new_lst = []
for tag in lst:
if tag.startswith("rating:"):
original_danbooru_tag = tag.split(":")[1]
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
continue
new_lst.append(original_danbooru_tag)
return new_lst
async def interrogate(self, session, path): async def interrogate(self, session, path):
async with session.post( async with session.post(
f"{self.address}/", f"{self.address}/",
@ -163,6 +188,7 @@ async def download_images(ctx):
log.info("processing post %r", md5) log.info("processing post %r", md5)
existing_post = await fetch_post(ctx, md5) existing_post = await fetch_post(ctx, md5)
if existing_post: if existing_post:
log.info("already exists %r", md5)
continue continue
# download the post # download the post
@ -170,10 +196,10 @@ async def download_images(ctx):
post_url_path = Path(urlparse(post_file_url).path) post_url_path = Path(urlparse(post_file_url).path)
good_extension = False good_extension = False
for extension in ALLOWED_EXTENSIONS: for extension in ALLOWED_EXTENSIONS:
if extension in post_url_path.stem: if extension in post_url_path.suffix:
good_extension = True good_extension = True
if not good_extension: if not good_extension:
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem) log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
continue continue
post_filename = post_url_path.name post_filename = post_url_path.name
post_filepath = DOWNLOADS / post_filename post_filepath = DOWNLOADS / post_filename
@ -263,12 +289,19 @@ async def fight(ctx):
) )
def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score one interrogator's tags against the danbooru tags of a post.

    The score is ``(matching - extra) / len(danbooru_tags)`` rounded to 10
    decimal places, where ``matching`` counts interrogator tags present in
    ``danbooru_tags`` and ``extra`` counts those that are not.

    Args:
        danbooru_tags: reference tag set for the post.
        interrogator_tags: tag set produced by the interrogator.

    Returns:
        A ``(score, tags_not_in_danbooru)`` tuple; the second item is the
        set of interrogator tags absent from ``danbooru_tags``.
    """
    tags_in_danbooru = danbooru_tags & interrogator_tags
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    # A post with no reference tags would otherwise divide by zero and abort
    # the caller; clamp the denominator so such a post scores 0 for an empty
    # prediction and -1 per spurious tag.
    denominator = len(danbooru_tags) or 1
    raw_score = decimal.Decimal(
        len(tags_in_danbooru) - len(tags_not_in_danbooru)
    ) / decimal.Decimal(denominator)
    return round(raw_score, 10), tags_not_in_danbooru
async def scores(ctx): async def scores(ctx):
@ -277,28 +310,36 @@ async def scores(ctx):
all_rows = await ctx.db.execute_fetchall("select md5 from posts") all_rows = await ctx.db.execute_fetchall("select md5 from posts")
all_hashes = set(r[0] for r in all_rows) all_hashes = set(r[0] for r in all_rows)
absolute_scores = defaultdict(decimal.Decimal) # absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict)
incorrect_tags_counters = defaultdict(Counter)
for md5_hash in all_hashes: for md5_hash in all_hashes:
log.info("processing for %r", md5_hash) log.info("processing for %r", md5_hash)
post = await fetch_post(ctx, md5_hash) post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split()) danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators: for interrogator in interrogators:
rows = await ctx.db.execute_fetchall( interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?", tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
(md5_hash, interrogator.model_id), for tag in incorrect_tags:
) incorrect_tags_counters[interrogator.model_id][tag] += 1
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0] model_scores[interrogator.model_id][md5_hash] = {
interrogator_tags = set(tag_string.split()) "score": tagging_score,
absolute_scores[interrogator.model_id] += score( "incorrect_tags": incorrect_tags,
danbooru_tags, interrogator_tags }
)
summed_scores = {
model_id: sum(d["score"] for d in post_scores.values())
for model_id, post_scores in model_scores.items()
}
normalized_scores = { normalized_scores = {
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores model: summed_scores[model] / len(all_hashes) for model in summed_scores
} }
print("scores are [worst...best]")
for model in sorted( for model in sorted(
normalized_scores.keys(), normalized_scores.keys(),
key=lambda model: normalized_scores[model], key=lambda model: normalized_scores[model],
@ -306,6 +347,30 @@ async def scores(ctx):
): ):
print(model, normalized_scores[model]) print(model, normalized_scores[model])
print("[", end="")
# Report the four lowest-scoring posts for this model.
worst_hashes = sorted(
    model_scores[model],
    key=lambda md5_hash: model_scores[model][md5_hash]["score"],
)[:4]
for bad_md5_hash in worst_hashes:
    data = model_scores[model][bad_md5_hash]
    if os.getenv("DEBUG", "0") == "1":
        # Print the hash of the post being reported; `md5_hash` here would
        # be the stale variable left over from the loop over all_hashes.
        print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
    else:
        print(data["score"], end=",")
print("...", end="")
for good_md5_hash in sorted(
model_scores[model].keys(),
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
reverse=True,
)[:4]:
data = model_scores[model][good_md5_hash]
print(data["score"], end=",")
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
async def realmain(ctx): async def realmain(ctx):
await ctx.db.executescript( await ctx.db.executescript(