diff --git a/main.py b/main.py
index ed9ca39..40a6b06 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,3 @@
-import os
 import json
 import aiohttp
 import pprint
@@ -9,10 +8,10 @@ import asyncio
 import aiosqlite
 import base64
 import logging
-from collections import defaultdict, Counter
+from collections import defaultdict
 from pathlib import Path
 from urllib.parse import urlparse
-from typing import Any, Optional, List, Set, Tuple
+from typing import Any, Optional, List, Set
 from aiolimiter import AsyncLimiter
 from dataclasses import dataclass, field
 
@@ -34,32 +33,8 @@ class Interrogator:
     model_id: str
     address: str
 
-    def _process(self, lst):
-        return lst
-
-    async def fetch_tags(self, ctx, md5_hash):
-        rows = await ctx.db.execute_fetchall(
-            "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
-            (md5_hash, self.model_id),
-        )
-        assert len(rows) == 1  # run 'fight' mode, there's missing posts
-        tag_string = rows[0][0]
-        return self._process(tag_string.split())
-
 
 class DDInterrogator(Interrogator):
-    def _process(self, lst):
-        new_lst = []
-        for tag in lst:
-            if tag.startswith("rating:"):
-                original_danbooru_tag = tag.split(":")[1]
-            else:
-                original_danbooru_tag = tag
-            if original_danbooru_tag == "safe":
-                continue
-            new_lst.append(original_danbooru_tag)
-        return new_lst
-
     async def interrogate(self, session, path):
         async with session.post(
             f"{self.address}/",
@@ -188,7 +163,6 @@ async def download_images(ctx):
             log.info("processing post %r", md5)
             existing_post = await fetch_post(ctx, md5)
             if existing_post:
-                log.info("already exists %r", md5)
                 continue
 
             # download the post
@@ -196,10 +170,10 @@ async def download_images(ctx):
             post_url_path = Path(urlparse(post_file_url).path)
             good_extension = False
             for extension in ALLOWED_EXTENSIONS:
-                if extension in post_url_path.suffix:
+                if extension in post_url_path.stem:
                     good_extension = True
             if not good_extension:
-                log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
+                log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
                 continue
             post_filename = post_url_path.name
             post_filepath = DOWNLOADS / post_filename
@@ -289,19 +263,12 @@ async def fight(ctx):
     )
 
 
-def score(
-    danbooru_tags: Set[str], interrogator_tags: Set[str]
-) -> Tuple[decimal.Decimal, Set[str]]:
+def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
     tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
-    return (
-        round(
-            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
-            / decimal.Decimal(len(danbooru_tags)),
-            10,
-        ),
-        tags_not_in_danbooru,
-    )
+    return decimal.Decimal(
+        len(tags_in_danbooru) - len(tags_not_in_danbooru)
+    ) / decimal.Decimal(len(danbooru_tags))
 
 
 async def scores(ctx):
@@ -310,36 +277,28 @@ async def scores(ctx):
     all_rows = await ctx.db.execute_fetchall("select md5 from posts")
     all_hashes = set(r[0] for r in all_rows)
 
-    # absolute_scores = defaultdict(decimal.Decimal)
-    model_scores = defaultdict(dict)
-    incorrect_tags_counters = defaultdict(Counter)
+    absolute_scores = defaultdict(decimal.Decimal)
     for md5_hash in all_hashes:
         log.info("processing for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
-            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
-            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
-            for tag in incorrect_tags:
-                incorrect_tags_counters[interrogator.model_id][tag] += 1
-
-            model_scores[interrogator.model_id][md5_hash] = {
-                "score": tagging_score,
-                "incorrect_tags": incorrect_tags,
-            }
-
-    summed_scores = {
-        model_id: sum(d["score"] for d in post_scores.values())
-        for model_id, post_scores in model_scores.items()
-    }
+            rows = await ctx.db.execute_fetchall(
+                "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
+                (md5_hash, interrogator.model_id),
+            )
+            assert len(rows) == 1  # run 'fight' mode, there's missing posts
+            tag_string = rows[0][0]
+            interrogator_tags = set(tag_string.split())
+            absolute_scores[interrogator.model_id] += score(
+                danbooru_tags, interrogator_tags
+            )
 
     normalized_scores = {
-        model: summed_scores[model] / len(all_hashes) for model in summed_scores
+        model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
     }
-    print("scores are [worst...best]")
-
     for model in sorted(
         normalized_scores.keys(),
         key=lambda model: normalized_scores[model],
@@ -347,30 +306,6 @@ async def scores(ctx):
     ):
         print(model, normalized_scores[model])
 
-        print("[", end="")
-
-        for bad_md5_hash in sorted(
-            model_scores[model].keys(),
-            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
-        )[:4]:
-            data = model_scores[model][bad_md5_hash]
-            if os.getenv("DEBUG", "0") == "1":
-                print(md5_hash, data["score"], " ".join(data["incorrect_tags"]))
-            else:
-                print(data["score"], end=",")
-        print("...", end="")
-
-        for good_md5_hash in sorted(
-            model_scores[model].keys(),
-            key=lambda md5_hash: model_scores[model][md5_hash]["score"],
-            reverse=True,
-        )[:4]:
-            data = model_scores[model][good_md5_hash]
-            print(data["score"], end=",")
-
-        print("]")
-        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
-
 
 async def realmain(ctx):
     await ctx.db.executescript(