Compare commits
3 commits
18da5d7972
...
099c7f8fc9
Author | SHA1 | Date | |
---|---|---|---|
099c7f8fc9 | |||
f3ec3ab6d3 | |||
919bd4017e |
1 changed files with 85 additions and 20 deletions
105
main.py
105
main.py
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import json
|
import json
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import pprint
|
import pprint
|
||||||
|
@ -8,10 +9,10 @@ import asyncio
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
from collections import defaultdict, Counter
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from typing import Any, Optional, List, Set
|
from typing import Any, Optional, List, Set, Tuple
|
||||||
from aiolimiter import AsyncLimiter
|
from aiolimiter import AsyncLimiter
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
@ -33,8 +34,32 @@ class Interrogator:
|
||||||
model_id: str
|
model_id: str
|
||||||
address: str
|
address: str
|
||||||
|
|
||||||
|
def _process(self, lst):
|
||||||
|
return lst
|
||||||
|
|
||||||
|
async def fetch_tags(self, ctx, md5_hash):
|
||||||
|
rows = await ctx.db.execute_fetchall(
|
||||||
|
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
||||||
|
(md5_hash, self.model_id),
|
||||||
|
)
|
||||||
|
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||||
|
tag_string = rows[0][0]
|
||||||
|
return self._process(tag_string.split())
|
||||||
|
|
||||||
|
|
||||||
class DDInterrogator(Interrogator):
|
class DDInterrogator(Interrogator):
|
||||||
|
def _process(self, lst):
|
||||||
|
new_lst = []
|
||||||
|
for tag in lst:
|
||||||
|
if tag.startswith("rating:"):
|
||||||
|
original_danbooru_tag = tag.split(":")[1]
|
||||||
|
else:
|
||||||
|
original_danbooru_tag = tag
|
||||||
|
if original_danbooru_tag == "safe":
|
||||||
|
continue
|
||||||
|
new_lst.append(original_danbooru_tag)
|
||||||
|
return new_lst
|
||||||
|
|
||||||
async def interrogate(self, session, path):
|
async def interrogate(self, session, path):
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{self.address}/",
|
f"{self.address}/",
|
||||||
|
@ -163,6 +188,7 @@ async def download_images(ctx):
|
||||||
log.info("processing post %r", md5)
|
log.info("processing post %r", md5)
|
||||||
existing_post = await fetch_post(ctx, md5)
|
existing_post = await fetch_post(ctx, md5)
|
||||||
if existing_post:
|
if existing_post:
|
||||||
|
log.info("already exists %r", md5)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# download the post
|
# download the post
|
||||||
|
@ -170,10 +196,10 @@ async def download_images(ctx):
|
||||||
post_url_path = Path(urlparse(post_file_url).path)
|
post_url_path = Path(urlparse(post_file_url).path)
|
||||||
good_extension = False
|
good_extension = False
|
||||||
for extension in ALLOWED_EXTENSIONS:
|
for extension in ALLOWED_EXTENSIONS:
|
||||||
if extension in post_url_path.stem:
|
if extension in post_url_path.suffix:
|
||||||
good_extension = True
|
good_extension = True
|
||||||
if not good_extension:
|
if not good_extension:
|
||||||
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
|
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
|
||||||
continue
|
continue
|
||||||
post_filename = post_url_path.name
|
post_filename = post_url_path.name
|
||||||
post_filepath = DOWNLOADS / post_filename
|
post_filepath = DOWNLOADS / post_filename
|
||||||
|
@ -263,12 +289,19 @@ async def fight(ctx):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
|
def score(
|
||||||
|
danbooru_tags: Set[str], interrogator_tags: Set[str]
|
||||||
|
) -> Tuple[decimal.Decimal, Set[str]]:
|
||||||
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
|
tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
|
||||||
tags_not_in_danbooru = interrogator_tags - danbooru_tags
|
tags_not_in_danbooru = interrogator_tags - danbooru_tags
|
||||||
return decimal.Decimal(
|
return (
|
||||||
len(tags_in_danbooru) - len(tags_not_in_danbooru)
|
round(
|
||||||
) / decimal.Decimal(len(danbooru_tags))
|
decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
|
||||||
|
/ decimal.Decimal(len(danbooru_tags)),
|
||||||
|
10,
|
||||||
|
),
|
||||||
|
tags_not_in_danbooru,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def scores(ctx):
|
async def scores(ctx):
|
||||||
|
@ -277,28 +310,36 @@ async def scores(ctx):
|
||||||
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
||||||
all_hashes = set(r[0] for r in all_rows)
|
all_hashes = set(r[0] for r in all_rows)
|
||||||
|
|
||||||
absolute_scores = defaultdict(decimal.Decimal)
|
# absolute_scores = defaultdict(decimal.Decimal)
|
||||||
|
model_scores = defaultdict(dict)
|
||||||
|
incorrect_tags_counters = defaultdict(Counter)
|
||||||
|
|
||||||
for md5_hash in all_hashes:
|
for md5_hash in all_hashes:
|
||||||
log.info("processing for %r", md5_hash)
|
log.info("processing for %r", md5_hash)
|
||||||
post = await fetch_post(ctx, md5_hash)
|
post = await fetch_post(ctx, md5_hash)
|
||||||
danbooru_tags = set(post["tag_string"].split())
|
danbooru_tags = set(post["tag_string"].split())
|
||||||
for interrogator in interrogators:
|
for interrogator in interrogators:
|
||||||
rows = await ctx.db.execute_fetchall(
|
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
|
||||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
|
||||||
(md5_hash, interrogator.model_id),
|
for tag in incorrect_tags:
|
||||||
)
|
incorrect_tags_counters[interrogator.model_id][tag] += 1
|
||||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
|
||||||
tag_string = rows[0][0]
|
model_scores[interrogator.model_id][md5_hash] = {
|
||||||
interrogator_tags = set(tag_string.split())
|
"score": tagging_score,
|
||||||
absolute_scores[interrogator.model_id] += score(
|
"incorrect_tags": incorrect_tags,
|
||||||
danbooru_tags, interrogator_tags
|
}
|
||||||
)
|
|
||||||
|
summed_scores = {
|
||||||
|
model_id: sum(d["score"] for d in post_scores.values())
|
||||||
|
for model_id, post_scores in model_scores.items()
|
||||||
|
}
|
||||||
|
|
||||||
normalized_scores = {
|
normalized_scores = {
|
||||||
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
|
model: summed_scores[model] / len(all_hashes) for model in summed_scores
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print("scores are [worst...best]")
|
||||||
|
|
||||||
for model in sorted(
|
for model in sorted(
|
||||||
normalized_scores.keys(),
|
normalized_scores.keys(),
|
||||||
key=lambda model: normalized_scores[model],
|
key=lambda model: normalized_scores[model],
|
||||||
|
@ -306,6 +347,30 @@ async def scores(ctx):
|
||||||
):
|
):
|
||||||
print(model, normalized_scores[model])
|
print(model, normalized_scores[model])
|
||||||
|
|
||||||
|
print("[", end="")
|
||||||
|
|
||||||
|
for bad_md5_hash in sorted(
|
||||||
|
model_scores[model].keys(),
|
||||||
|
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||||
|
)[:4]:
|
||||||
|
data = model_scores[model][bad_md5_hash]
|
||||||
|
if os.getenv("DEBUG", "0") == "1":
|
||||||
|
print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
|
||||||
|
else:
|
||||||
|
print(data["score"], end=",")
|
||||||
|
print("...", end="")
|
||||||
|
|
||||||
|
for good_md5_hash in sorted(
|
||||||
|
model_scores[model].keys(),
|
||||||
|
key=lambda md5_hash: model_scores[model][md5_hash]["score"],
|
||||||
|
reverse=True,
|
||||||
|
)[:4]:
|
||||||
|
data = model_scores[model][good_md5_hash]
|
||||||
|
print(data["score"], end=",")
|
||||||
|
|
||||||
|
print("]")
|
||||||
|
print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
|
||||||
|
|
||||||
|
|
||||||
async def realmain(ctx):
|
async def realmain(ctx):
|
||||||
await ctx.db.executescript(
|
await ctx.db.executescript(
|
||||||
|
|
Loading…
Reference in a new issue