Compare commits
3 commits
18da5d7972
...
099c7f8fc9
Author | SHA1 | Date | |
---|---|---|---|
099c7f8fc9 | |||
f3ec3ab6d3 | |||
919bd4017e |
1 changed files with 85 additions and 20 deletions
105
main.py
105
main.py
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import json
|
||||
import aiohttp
|
||||
import pprint
|
||||
|
@ -8,10 +9,10 @@ import asyncio
|
|||
import aiosqlite
|
||||
import base64
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, Counter
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Optional, List, Set
|
||||
from typing import Any, Optional, List, Set, Tuple
|
||||
from aiolimiter import AsyncLimiter
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
@ -33,8 +34,32 @@ class Interrogator:
|
|||
model_id: str
|
||||
address: str
|
||||
|
||||
def _process(self, lst):
|
||||
return lst
|
||||
|
||||
async def fetch_tags(self, ctx, md5_hash):
|
||||
rows = await ctx.db.execute_fetchall(
|
||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
||||
(md5_hash, self.model_id),
|
||||
)
|
||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||
tag_string = rows[0][0]
|
||||
return self._process(tag_string.split())
|
||||
|
||||
|
||||
class DDInterrogator(Interrogator):
|
||||
def _process(self, lst):
|
||||
new_lst = []
|
||||
for tag in lst:
|
||||
if tag.startswith("rating:"):
|
||||
original_danbooru_tag = tag.split(":")[1]
|
||||
else:
|
||||
original_danbooru_tag = tag
|
||||
if original_danbooru_tag == "safe":
|
||||
continue
|
||||
new_lst.append(original_danbooru_tag)
|
||||
return new_lst
|
||||
|
||||
async def interrogate(self, session, path):
|
||||
async with session.post(
|
||||
f"{self.address}/",
|
||||
|
@ -163,6 +188,7 @@ async def download_images(ctx):
|
|||
log.info("processing post %r", md5)
|
||||
existing_post = await fetch_post(ctx, md5)
|
||||
if existing_post:
|
||||
log.info("already exists %r", md5)
|
||||
continue
|
||||
|
||||
# download the post
|
||||
|
@ -170,10 +196,10 @@ async def download_images(ctx):
|
|||
post_url_path = Path(urlparse(post_file_url).path)
|
||||
good_extension = False
|
||||
for extension in ALLOWED_EXTENSIONS:
|
||||
if extension in post_url_path.stem:
|
||||
if extension in post_url_path.suffix:
|
||||
good_extension = True
|
||||
if not good_extension:
|
||||
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
|
||||
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
|
||||
continue
|
||||
post_filename = post_url_path.name
|
||||
post_filepath = DOWNLOADS / post_filename
|
||||
|
@ -263,12 +289,19 @@ async def fight(ctx):
|
|||
)
|
||||
|
||||
|
||||
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
|
||||
def score(
|
||||
danbooru_tags: Set[str], interrogator_tags: Set[str]
|
||||
) -> Tuple[decimal.Decimal, Set[str]]:
|
||||
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
|
||||
tags_not_in_danbooru = interrogator_tags - danbooru_tags
|
||||
return decimal.Decimal(
|
||||
len(tags_in_danbooru) - len(tags_not_in_danbooru)
|
||||
) / decimal.Decimal(len(danbooru_tags))
|
||||
return (
|
||||
round(
|
||||
decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
|
||||
/ decimal.Decimal(len(danbooru_tags)),
|
||||
10,
|
||||
),
|
||||
tags_not_in_danbooru,
|
||||
)
|
||||
|
||||
|
||||
async def scores(ctx):
|
||||
|
@ -277,28 +310,36 @@ async def scores(ctx):
|
|||
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
||||
all_hashes = set(r[0] for r in all_rows)
|
||||
|
||||
absolute_scores = defaultdict(decimal.Decimal)
|
||||
# absolute_scores = defaultdict(decimal.Decimal)
|
||||
model_scores = defaultdict(dict)
|
||||
incorrect_tags_counters = defaultdict(Counter)
|
||||
|
||||
for md5_hash in all_hashes:
|
||||
log.info("processing for %r", md5_hash)
|
||||
post = await fetch_post(ctx, md5_hash)
|
||||
danbooru_tags = set(post["tag_string"].split())
|
||||
for interrogator in interrogators:
|
||||
rows = await ctx.db.execute_fetchall(
|
||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
||||
(md5_hash, interrogator.model_id),
|
||||
)
|
||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||
tag_string = rows[0][0]
|
||||
interrogator_tags = set(tag_string.split())
|
||||
absolute_scores[interrogator.model_id] += score(
|
||||
danbooru_tags, interrogator_tags
|
||||
)
|
||||
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
|
||||
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
|
||||
for tag in incorrect_tags:
|
||||
incorrect_tags_counters[interrogator.model_id][tag] += 1
|
||||
|
||||
model_scores[interrogator.model_id][md5_hash] = {
|
||||
"score": tagging_score,
|
||||
"incorrect_tags": incorrect_tags,
|
||||
}
|
||||
|
||||
summed_scores = {
|
||||
model_id: sum(d["score"] for d in post_scores.values())
|
||||
for model_id, post_scores in model_scores.items()
|
||||
}
|
||||
|
||||
normalized_scores = {
|
||||
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
|
||||
model: summed_scores[model] / len(all_hashes) for model in summed_scores
|
||||
}
|
||||
|
||||
print("scores are [worst...best]")
|
||||
|
||||
for model in sorted(
|
||||
normalized_scores.keys(),
|
||||
key=lambda model: normalized_scores[model],
|
||||
|
@ -306,6 +347,30 @@ async def scores(ctx):
|
|||
):
|
||||
print(model, normalized_scores[model])
|
||||
|
||||
print("[", end="")
|
||||
|
||||
for bad_md5_hash in sorted(
|
||||
model_scores[model].keys(),
|
||||
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||
)[:4]:
|
||||
data = model_scores[model][bad_md5_hash]
|
||||
if os.getenv("DEBUG", "0") == "1":
|
||||
print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
|
||||
else:
|
||||
print(data["score"], end=",")
|
||||
print("...", end="")
|
||||
|
||||
for good_md5_hash in sorted(
|
||||
model_scores[model].keys(),
|
||||
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||
reverse=True,
|
||||
)[:4]:
|
||||
data = model_scores[model][good_md5_hash]
|
||||
print(data["score"], end=",")
|
||||
|
||||
print("]")
|
||||
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
|
||||
|
||||
|
||||
async def realmain(ctx):
|
||||
await ctx.db.executescript(
|
||||
|
|
Loading…
Reference in a new issue