Compare commits


No commits in common. "099c7f8fc9b31cc3c2de4a832ab745c46a74ca2d" and "18da5d7972b96f0b8b45b04dceb47e49c268fe38" have entirely different histories.

1 changed file with 20 additions and 85 deletions

main.py

@@ -1,4 +1,3 @@
import os
import json
import aiohttp
import pprint
@@ -9,10 +8,10 @@ import asyncio
import aiosqlite
import base64
import logging
from collections import defaultdict, Counter
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set, Tuple
from typing import Any, Optional, List, Set
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@@ -34,32 +33,8 @@ class Interrogator:
model_id: str
address: str
def _process(self, lst):
return lst
async def fetch_tags(self, ctx, md5_hash):
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
return self._process(tag_string.split())
class DDInterrogator(Interrogator):
def _process(self, lst):
new_lst = []
for tag in lst:
if tag.startswith("rating:"):
original_danbooru_tag = tag.split(":")[1]
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
continue
new_lst.append(original_danbooru_tag)
return new_lst
async def interrogate(self, session, path):
async with session.post(
f"{self.address}/",
@@ -188,7 +163,6 @@ async def download_images(ctx):
log.info("processing post %r", md5)
existing_post = await fetch_post(ctx, md5)
if existing_post:
log.info("already exists %r", md5)
continue
# download the post
@@ -196,10 +170,10 @@ async def download_images(ctx):
post_url_path = Path(urlparse(post_file_url).path)
good_extension = False
for extension in ALLOWED_EXTENSIONS:
if extension in post_url_path.suffix:
if extension in post_url_path.stem:
good_extension = True
if not good_extension:
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
continue
post_filename = post_url_path.name
post_filepath = DOWNLOADS / post_filename
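Note: one side of this hunk checks `post_url_path.suffix`, the other `post_url_path.stem`. In `pathlib`, `suffix` is the final extension including the dot, while `stem` is the filename without it, so the substring test behaves very differently. A small sketch with a hypothetical URL (the real `ALLOWED_EXTENSIONS` list is outside this diff and assumed here):

```python
from pathlib import Path
from urllib.parse import urlparse

# hypothetical values; ALLOWED_EXTENSIONS is defined elsewhere in main.py
ALLOWED_EXTENSIONS = ("png", "jpg", "jpeg")
post_file_url = "https://example.com/data/abcd1234.png"

post_url_path = Path(urlparse(post_file_url).path)
print(post_url_path.suffix)  # '.png'      -> the substring check against extensions matches
print(post_url_path.stem)    # 'abcd1234'  -> the extension is not part of this string
print(any(ext in post_url_path.suffix for ext in ALLOWED_EXTENSIONS))  # True
print(any(ext in post_url_path.stem for ext in ALLOWED_EXTENSIONS))    # False for this name
```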
@@ -289,19 +263,12 @@ async def fight(ctx):
)
def score(
danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
tags_not_in_danbooru = interrogator_tags - danbooru_tags
return (
round(
decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
/ decimal.Decimal(len(danbooru_tags)),
10,
),
tags_not_in_danbooru,
)
return decimal.Decimal(
len(tags_in_danbooru) - len(tags_not_in_danbooru)
) / decimal.Decimal(len(danbooru_tags))
async def scores(ctx):
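Note: the two variants of `score` in this hunk compute the same ratio, (matched tags − spurious tags) / |Danbooru tags|; the longer variant also rounds to 10 decimal places and returns the set of spurious tags. A minimal sketch of the underlying formula with hypothetical tag sets:

```python
import decimal

def score(danbooru_tags, interrogator_tags):
    # (correctly predicted tags - tags not present on the post) / number of reference tags
    hits = danbooru_tags & interrogator_tags
    misses = interrogator_tags - danbooru_tags
    return decimal.Decimal(len(hits) - len(misses)) / decimal.Decimal(len(danbooru_tags))

reference = {"1girl", "smile", "long_hair", "outdoors"}  # hypothetical Danbooru tags
predicted = {"1girl", "smile", "short_hair"}             # hypothetical interrogator tags
print(score(reference, predicted))  # (2 - 1) / 4 = 0.25
```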
@@ -310,36 +277,28 @@ async def scores(ctx):
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
all_hashes = set(r[0] for r in all_rows)
# absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict)
incorrect_tags_counters = defaultdict(Counter)
absolute_scores = defaultdict(decimal.Decimal)
for md5_hash in all_hashes:
log.info("processing for %r", md5_hash)
post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators:
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1
model_scores[interrogator.model_id][md5_hash] = {
"score": tagging_score,
"incorrect_tags": incorrect_tags,
}
summed_scores = {
model_id: sum(d["score"] for d in post_scores.values())
for model_id, post_scores in model_scores.items()
}
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, interrogator.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
interrogator_tags = set(tag_string.split())
absolute_scores[interrogator.model_id] += score(
danbooru_tags, interrogator_tags
)
normalized_scores = {
model: summed_scores[model] / len(all_hashes) for model in summed_scores
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
}
print("scores are [worst...best]")
for model in sorted(
normalized_scores.keys(),
key=lambda model: normalized_scores[model],
@@ -347,30 +306,6 @@ async def scores(ctx):
):
print(model, normalized_scores[model])
print("[", end="")
for bad_md5_hash in sorted(
model_scores[model].keys(),
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
)[:4]:
data = model_scores[model][bad_md5_hash]
if os.getenv("DEBUG", "0") == "1":
print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
else:
print(data["score"], end=",")
print("...", end="")
for good_md5_hash in sorted(
model_scores[model].keys(),
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
reverse=True,
)[:4]:
data = model_scores[model][good_md5_hash]
print(data["score"], end=",")
print("]")
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
async def realmain(ctx):
await ctx.db.executescript(
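Note: in both versions of `scores`, the printed ranking divides each model's accumulated score by the number of posts and sorts ascending ("worst...best"). A compact sketch of that aggregation, with invented model names and totals:

```python
import decimal

# hypothetical accumulated scores per model, as built by either version of scores()
absolute_scores = {
    "deepdanbooru": decimal.Decimal("0.75"),
    "wd14-vit":     decimal.Decimal("1.35"),
}
n_posts = 2

normalized_scores = {m: s / n_posts for m, s in absolute_scores.items()}

print("scores are [worst...best]")
for model in sorted(normalized_scores, key=normalized_scores.get):
    print(model, normalized_scores[model])
# deepdanbooru 0.375
# wd14-vit 0.675
```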