add scoring

This commit is contained in:
Luna 2023-06-10 00:36:45 -03:00
parent dbf458eed7
commit 6f3ce2ab02

76
main.py
View file

@ -1,5 +1,7 @@
import json
import aiohttp
import pprint
import decimal
import aiofiles
import sys
import asyncio
@ -9,7 +11,7 @@ import logging
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List
from typing import Any, Optional, List, Set
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@ -129,6 +131,13 @@ DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)
# Extensions (no leading dot) that download_images looks for before it
# downloads a post; anything else is logged and skipped.
ALLOWED_EXTENSIONS = ("jpg", "jpeg", "png")
async def download_images(ctx):
try:
tagquery = sys.argv[2]
@ -150,14 +159,23 @@ async def download_images(ctx):
for post in posts:
if "md5" not in post:
continue
log.info("processing post %r", post)
existing_post = await fetch_post(ctx, post["md5"])
md5 = post["md5"]
log.info("processing post %r", md5)
existing_post = await fetch_post(ctx, md5)
if existing_post:
continue
# download the post
post_file_url = post["file_url"]
post_filename = Path(urlparse(post_file_url).path).name
post_url_path = Path(urlparse(post_file_url).path)
good_extension = False
for extension in ALLOWED_EXTENSIONS:
if extension in post_url_path.stem:
good_extension = True
if not good_extension:
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
continue
post_filename = post_url_path.name
post_filepath = DOWNLOADS / post_filename
if not post_filepath.exists():
log.info("downloading %r to %r", post_file_url, post_filepath)
@ -219,8 +237,10 @@ async def fight(ctx):
log.info("missing %d hashes", len(missing_hashes))
for missing_hash in missing_hashes:
log.info("interrogating %r", missing_hash)
for index, missing_hash in enumerate(missing_hashes):
log.info(
"interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
)
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
log.info("got %r", tag_string)
@ -229,6 +249,50 @@ async def fight(ctx):
)
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
    """Score an interrogator's tags against a post's Danbooru tags.

    Every tag present in both sets counts +1, every tag the interrogator
    produced that Danbooru does not list counts -1, and the sum is divided
    by the number of Danbooru tags. An exact match therefore scores 1, and
    the score turns negative once wrong tags outnumber correct ones.

    Args:
        danbooru_tags: Reference tags of the post.
        interrogator_tags: Tags the interrogator produced for the same post.

    Returns:
        The normalized score, or 0 when ``danbooru_tags`` is empty (there is
        nothing to compare against).
    """
    if not danbooru_tags:
        # No reference tags: the ratio is undefined, so avoid dividing by zero.
        return decimal.Decimal(0)

    # Hits: tags the interrogator produced that Danbooru agrees with.
    tags_in_danbooru = danbooru_tags & interrogator_tags
    # False positives: tags the interrogator produced that Danbooru lacks.
    tags_not_in_danbooru = interrogator_tags - danbooru_tags
    return decimal.Decimal(
        len(tags_in_danbooru) - len(tags_not_in_danbooru)
    ) / decimal.Decimal(len(danbooru_tags))
async def scores(ctx):
    """Print each available model with its average score, best first.

    For every md5 in the ``posts`` table, the post's ``tag_string`` is
    compared (via ``score``) with the ``output_tag_string`` stored in
    ``interrogated_posts`` for each model in
    ``ctx.config.all_available_models``. Each model's summed score is then
    divided by the number of posts and printed in descending order.
    """
    available_models = ctx.config.all_available_models

    post_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in post_rows}
    post_count = len(all_hashes)

    # decimal.Decimal() is zero, so each model's running total starts at 0.
    totals = defaultdict(decimal.Decimal)

    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        reference_tags = set(post["tag_string"].split())

        for candidate in available_models:
            rows = await ctx.db.execute_fetchall(
                "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
                (md5_hash, candidate.model_id),
            )
            # Exactly one stored result is expected per (post, model) pair;
            # if this fails, run 'fight' mode first, there are missing posts.
            assert len(rows) == 1
            (output_tag_string,) = rows[0][:1]
            produced_tags = set(output_tag_string.split())
            totals[candidate.model_id] += score(reference_tags, produced_tags)

    normalized_scores = {
        model: total / post_count for model, total in totals.items()
    }

    ranking = sorted(normalized_scores, key=normalized_scores.get, reverse=True)
    for model in ranking:
        print(model, normalized_scores[model])
async def realmain(ctx):
await ctx.db.executescript(
"""