From 6f3ce2ab024d08df2a67cc07e4b7656f231534f2 Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 00:36:45 -0300 Subject: [PATCH] add scoring --- main.py | 76 ++++++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 6 deletions(-) diff --git a/main.py b/main.py index 922421c..5ca42aa 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,7 @@ import json import aiohttp +import pprint +import decimal import aiofiles import sys import asyncio @@ -9,7 +11,7 @@ import logging from collections import defaultdict from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional, List +from typing import Any, Optional, List, Set from aiolimiter import AsyncLimiter from dataclasses import dataclass, field @@ -129,6 +131,13 @@ DOWNLOADS = Path.cwd() / "posts" DOWNLOADS.mkdir(exist_ok=True) +ALLOWED_EXTENSIONS = ( + "jpg", + "jpeg", + "png", +) + + async def download_images(ctx): try: tagquery = sys.argv[2] @@ -150,14 +159,23 @@ async def download_images(ctx): for post in posts: if "md5" not in post: continue - log.info("processing post %r", post) - existing_post = await fetch_post(ctx, post["md5"]) + md5 = post["md5"] + log.info("processing post %r", md5) + existing_post = await fetch_post(ctx, md5) if existing_post: continue # download the post post_file_url = post["file_url"] - post_filename = Path(urlparse(post_file_url).path).name + post_url_path = Path(urlparse(post_file_url).path) + good_extension = False + for extension in ALLOWED_EXTENSIONS: + if extension in post_url_path.stem: + good_extension = True + if not good_extension: + log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem) + continue + post_filename = post_url_path.name post_filepath = DOWNLOADS / post_filename if not post_filepath.exists(): log.info("downloading %r to %r", post_file_url, post_filepath) @@ -219,8 +237,10 @@ async def fight(ctx): log.info("missing %d hashes", len(missing_hashes)) - for missing_hash in 
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
    """Score one interrogator output against a post's danbooru tags.

    The score is ``(matched - spurious) / len(danbooru_tags)`` where
    ``matched`` counts interrogator tags that are also danbooru tags and
    ``spurious`` counts interrogator tags that are not. A perfect answer
    scores 1; emitting wrong tags pushes the score below zero.

    Args:
        danbooru_tags: reference tag set of the post.
        interrogator_tags: tag set produced by the interrogator.

    Returns:
        The score as a ``decimal.Decimal``. Zero is returned when
        ``danbooru_tags`` is empty, because there is nothing to compare
        against and the ratio would otherwise divide by zero.
    """
    if not danbooru_tags:
        return decimal.Decimal(0)

    # Intersection, not union: with a union the formula degenerates to
    # len(interrogator_tags) / len(danbooru_tags), which rewards a model for
    # emitting many tags whether or not any of them is correct.
    tags_in_danbooru = danbooru_tags & interrogator_tags
    # Tags the interrogator produced that the reference does not contain.
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    return decimal.Decimal(
        len(tags_in_danbooru) - len(tags_not_in_danbooru)
    ) / decimal.Decimal(len(danbooru_tags))


async def scores(ctx):
    """Print each interrogator model with its mean score, best model first.

    For every md5 in the ``posts`` table, the post's ``tag_string`` is
    compared (via ``score``) with the ``output_tag_string`` stored in
    ``interrogated_posts`` for each model in
    ``ctx.config.all_available_models``. Per-model sums are divided by the
    number of posts and printed in descending order.

    Raises:
        RuntimeError: if a (post, model) pair does not have exactly one row
            in ``interrogated_posts``.
    """
    interrogators = ctx.config.all_available_models

    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}

    # decimal.Decimal() is 0, so unseen models start from a zero sum.
    absolute_scores = defaultdict(decimal.Decimal)

    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())
        for interrogator in interrogators:
            rows = await ctx.db.execute_fetchall(
                "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
                (md5_hash, interrogator.model_id),
            )
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(rows) != 1:
                raise RuntimeError(
                    f"expected exactly one interrogated_posts row for post "
                    f"{md5_hash!r} and model {interrogator.model_id!r}, "
                    f"found {len(rows)}; run 'fight' mode first"
                )
            interrogator_tags = set(rows[0][0].split())
            absolute_scores[interrogator.model_id] += score(
                danbooru_tags, interrogator_tags
            )

    # absolute_scores only has entries when at least one post was processed,
    # so len(all_hashes) is never zero where the division actually runs.
    normalized_scores = {
        model: total / len(all_hashes) for model, total in absolute_scores.items()
    }

    for model, normalized in sorted(
        normalized_scores.items(),
        key=lambda item: item[1],
        reverse=True,
    ):
        print(model, normalized)