add scoring
This commit is contained in:
parent
dbf458eed7
commit
6f3ce2ab02
1 changed files with 70 additions and 6 deletions
76
main.py
76
main.py
|
@ -1,5 +1,7 @@
|
|||
import json
|
||||
import aiohttp
|
||||
import pprint
|
||||
import decimal
|
||||
import aiofiles
|
||||
import sys
|
||||
import asyncio
|
||||
|
@ -9,7 +11,7 @@ import logging
|
|||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Optional, List
|
||||
from typing import Any, Optional, List, Set
|
||||
from aiolimiter import AsyncLimiter
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
@ -129,6 +131,13 @@ DOWNLOADS = Path.cwd() / "posts"
|
|||
DOWNLOADS.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
ALLOWED_EXTENSIONS = (
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"png",
|
||||
)
|
||||
|
||||
|
||||
async def download_images(ctx):
|
||||
try:
|
||||
tagquery = sys.argv[2]
|
||||
|
@ -150,14 +159,23 @@ async def download_images(ctx):
|
|||
for post in posts:
|
||||
if "md5" not in post:
|
||||
continue
|
||||
log.info("processing post %r", post)
|
||||
existing_post = await fetch_post(ctx, post["md5"])
|
||||
md5 = post["md5"]
|
||||
log.info("processing post %r", md5)
|
||||
existing_post = await fetch_post(ctx, md5)
|
||||
if existing_post:
|
||||
continue
|
||||
|
||||
# download the post
|
||||
post_file_url = post["file_url"]
|
||||
post_filename = Path(urlparse(post_file_url).path).name
|
||||
post_url_path = Path(urlparse(post_file_url).path)
|
||||
good_extension = False
|
||||
for extension in ALLOWED_EXTENSIONS:
|
||||
if extension in post_url_path.stem:
|
||||
good_extension = True
|
||||
if not good_extension:
|
||||
log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.stem)
|
||||
continue
|
||||
post_filename = post_url_path.name
|
||||
post_filepath = DOWNLOADS / post_filename
|
||||
if not post_filepath.exists():
|
||||
log.info("downloading %r to %r", post_file_url, post_filepath)
|
||||
|
@ -219,8 +237,10 @@ async def fight(ctx):
|
|||
|
||||
log.info("missing %d hashes", len(missing_hashes))
|
||||
|
||||
for missing_hash in missing_hashes:
|
||||
log.info("interrogating %r", missing_hash)
|
||||
for index, missing_hash in enumerate(missing_hashes):
|
||||
log.info(
|
||||
"interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
|
||||
)
|
||||
post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
|
||||
tag_string = await interrogator.interrogate(ctx.session, post_filepath)
|
||||
log.info("got %r", tag_string)
|
||||
|
@ -229,6 +249,50 @@ async def fight(ctx):
|
|||
)
|
||||
|
||||
|
||||
def score(danbooru_tags: Set[str], interrogator_tags: Set[str]) -> decimal.Decimal:
|
||||
tags_in_danbooru = danbooru_tags | interrogator_tags
|
||||
tags_not_in_danbooru = danbooru_tags - interrogator_tags
|
||||
return decimal.Decimal(
|
||||
len(tags_in_danbooru) - len(tags_not_in_danbooru)
|
||||
) / decimal.Decimal(len(danbooru_tags))
|
||||
|
||||
|
||||
async def scores(ctx):
|
||||
interrogators = ctx.config.all_available_models
|
||||
|
||||
all_rows = await ctx.db.execute_fetchall("select md5 from posts")
|
||||
all_hashes = set(r[0] for r in all_rows)
|
||||
|
||||
absolute_scores = defaultdict(decimal.Decimal)
|
||||
|
||||
for md5_hash in all_hashes:
|
||||
log.info("processing for %r", md5_hash)
|
||||
post = await fetch_post(ctx, md5_hash)
|
||||
danbooru_tags = set(post["tag_string"].split())
|
||||
for interrogator in interrogators:
|
||||
rows = await ctx.db.execute_fetchall(
|
||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
||||
(md5_hash, interrogator.model_id),
|
||||
)
|
||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||
tag_string = rows[0][0]
|
||||
interrogator_tags = set(tag_string.split())
|
||||
absolute_scores[interrogator.model_id] += score(
|
||||
danbooru_tags, interrogator_tags
|
||||
)
|
||||
|
||||
normalized_scores = {
|
||||
model: absolute_scores[model] / len(all_hashes) for model in absolute_scores
|
||||
}
|
||||
|
||||
for model in sorted(
|
||||
normalized_scores.keys(),
|
||||
key=lambda model: normalized_scores[model],
|
||||
reverse=True,
|
||||
):
|
||||
print(model, normalized_scores[model])
|
||||
|
||||
|
||||
async def realmain(ctx):
|
||||
await ctx.db.executescript(
|
||||
"""
|
||||
|
|
Loading…
Reference in a new issue