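"""tagger-showdown/main.py: compare image taggers against Danbooru's own tags.

Usage: ``python main.py MODE`` where MODE is ``download_images [TAG_QUERY [LIMIT]]``,
``fight`` or ``scores``. Settings are read from ``config.json`` in the
working directory.
"""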
import asyncio
import base64
import decimal
import enum
import json
import logging
import os
import pprint
import sys
from collections import Counter, defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional, Set, Tuple
from urllib.parse import urlparse

import aiofiles
import aiohttp
import aiosqlite
from aiolimiter import AsyncLimiter
log = logging.getLogger(__name__)

# sd-webui tagger model ids benchmarked when config.json omits
# "sd_webui_models" (Config copies this list for each instance).
DEFAULTS = """
    wd14-vit-v2-git
    wd14-vit
    wd14-swinv2-v2-git
    wd14-convnextv2-v2-git
    wd14-convnext-v2-git
    wd14-convnext
""".split()


@dataclass
class Interrogator:
    """Base class for a tagging model whose outputs are stored in the database.

    Subclasses implement ``interrogate(session, path)`` to tag one image and
    return a space-separated tag string. ``fetch_tags`` reads that stored
    string back and passes it through ``_process``, which subclasses may
    override to rewrite tags before scoring.
    """

    model_id: str  # value stored in interrogated_posts.model_name
    address: str  # base URL that subclasses send their requests to

    def _process(self, lst):
        """Post-process a list of stored tags; the base implementation is the identity."""
        return lst

    async def interrogate(self, session, path):
        """Tag the image at *path* and return a space-separated tag string.

        Raises:
            NotImplementedError: always; concrete interrogators override this.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement interrogate()")

    async def fetch_tags(self, ctx, md5_hash):
        """Return this model's stored tags for post *md5_hash*, after ``_process``.

        Raises:
            AssertionError: there is not exactly one stored result for this
                post and model (run 'fight' mode first). Raised explicitly
                rather than via ``assert`` so it survives ``python -O``.
        """
        rows = await ctx.db.execute_fetchall(
            "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
            (md5_hash, self.model_id),
        )
        if len(rows) != 1:
            raise AssertionError(
                f"expected 1 stored result for post {md5_hash!r} from model "
                f"{self.model_id!r}, got {len(rows)}; run 'fight' mode first"
            )
        tag_string = rows[0][0]
        return self._process(tag_string.split())


class DDInterrogator(Interrogator):
    """Tagger reached through the service at ``Config.dd_address``.

    Presumably DeepDanbooru, going by the "DD" naming; confirm.
    """

    def _process(self, lst):
        """Strip a leading ``rating:`` from tags and drop ``safe``.

        fetch_post appends Danbooru ratings as bare words (general, sensitive,
        questionable, explicit). Tags from this service may carry a
        ``rating:`` prefix, and ``safe`` has no counterpart in that mapping,
        so it is discarded.
        """
        stripped = (tag.removeprefix("rating:") for tag in lst)
        return [tag for tag in stripped if tag != "safe"]

    async def interrogate(self, session, path):
        """Upload *path* as form field ``file`` with threshold 0.7 and return the tags.

        Tag names have spaces replaced by underscores and are joined with
        single spaces.
        """
        # The ``with`` closes the image file once the upload is done; the
        # original passed an open handle that was never closed.
        with path.open("rb") as image_file:
            async with session.post(
                f"{self.address}/",
                params={
                    "threshold": "0.7",
                },
                # NOTE(review): hard-coded bearer token; consider moving it to config.json.
                headers={"Authorization": "Bearer sex"},
                data={"file": image_file},
            ) as resp:
                assert resp.status == 200
                tags = await resp.json()
        return " ".join(tag.replace(" ", "_") for tag in tags)
class SDInterrogator(Interrogator):
    """Tagger reached through an SD WebUI ``/tagger/v1/interrogate`` endpoint."""

    async def interrogate(self, session, path):
        """Send the image at *path* (base64-encoded) and return its tags.

        The request names ``model_id`` as the model and uses threshold 0.7.
        The keys of the response's ``caption`` mapping are returned with
        spaces replaced by underscores, joined by single spaces.
        """
        async with aiofiles.open(path, "rb") as image_file:
            raw_bytes = await image_file.read()
        payload = {
            "model": self.model_id,
            "threshold": 0.7,
            "image": base64.b64encode(raw_bytes).decode("utf-8"),
        }
        endpoint = f"{self.address}/tagger/v1/interrogate"
        async with session.post(endpoint, json=payload) as response:
            log.info("got %d", response.status)
            assert response.status == 200
            caption = (await response.json())["caption"]
        return " ".join(name.replace(" ", "_") for name in caption)


@dataclass
class Config:
    """Settings loaded from ``config.json``; its keys map directly onto these fields."""

    sd_webui_address: str  # base URL for SDInterrogator requests
    dd_address: str  # base URL for DDInterrogator requests
    dd_model_name: str  # model_id recorded for DDInterrogator results
    # SD WebUI tagger models to benchmark; a fresh copy of DEFAULTS when omitted.
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Any]:
        """Return one interrogator per configured model: each SD WebUI model, then the DD model."""
        interrogators: List[Any] = [
            SDInterrogator(model_name, self.sd_webui_address)
            for model_name in self.sd_webui_models
        ]
        interrogators.append(DDInterrogator(self.dd_model_name, self.dd_address))
        return interrogators


@dataclass
class Context:
    """Handles shared by every mode: the database, the HTTP session and the config."""

    # The aiosqlite connection opened in main(); typed loosely as Any.
    db: Any
    session: aiohttp.ClientSession
    config: Config

class HashStyle(enum.Enum):
    """Hash a booru uses to identify post files; only md5 is used in this file."""

    md5 = "md5"


@dataclass
class Booru:
    """Rate-limited access to a booru's HTTP API; subclasses add endpoints."""

    session: aiohttp.ClientSession
    limiter: AsyncLimiter  # wraps post-search requests (see Danbooru.posts)
    tag_limiter: AsyncLimiter  # NOTE(review): not used by any code in this file
    # Lazily created per-key locks. NOTE(review): unused by the code in this file.
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    # No annotation, so this is a plain class attribute, not a dataclass field.
    fetch_type = "hash"

    @property
    def hash_style(self):
        """Return the hash style this booru identifies posts by.

        The original referenced an undefined ``HashStyle``, so reading this
        property always raised NameError.
        """
        return HashStyle.md5
class Danbooru(Booru):
    """Danbooru API client."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        """Search ``/posts.json`` for *tag_query* and return the decoded JSON.

        *limit* is forwarded as the ``limit`` query parameter. The request
        waits on ``self.limiter`` before it is sent.
        """
        log.info("%s: submit %r", self.title, tag_query)
        query_params = {"tags": tag_query, "limit": limit}
        endpoint = f"{self.base_url}/posts.json"
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(endpoint, params=query_params) as response:
                assert response.status == 200
                return await response.json()


# Directory that holds downloaded post images; created at import time.
# fight() later locates each image by globbing for its md5 prefix here.
DOWNLOADS = Path.cwd().joinpath("posts")
DOWNLOADS.mkdir(exist_ok=True)

# File extensions (without the dot) that download_images() checks post files against.
ALLOWED_EXTENSIONS = ("jpg", "jpeg", "png")


def _has_allowed_extension(path: Path) -> bool:
    """Return whether *path*'s suffix is exactly one of ALLOWED_EXTENSIONS (case-insensitive)."""
    # Exact comparison: the old substring test also accepted e.g. ".apng".
    return path.suffix.lower().removeprefix(".") in ALLOWED_EXTENSIONS


async def _download_file(session, url: str, destination: Path) -> None:
    """Stream *url* into *destination*, moving it into place only once complete."""
    # Write to a hidden sibling first. An interrupted download then never
    # leaves a truncated file at *destination* (which a later run would treat
    # as already downloaded). The leading dot also keeps it out of fight()'s
    # ``{md5}*`` glob.
    partial = destination.with_name(f".{destination.name}.part")
    try:
        async with session.get(url) as resp:
            assert resp.status == 200
            with partial.open("wb") as fd:
                async for chunk in resp.content.iter_chunked(64 * 1024):
                    fd.write(chunk)
        partial.replace(destination)
    finally:
        # No-op after a successful replace(); cleans up after a failure.
        partial.unlink(missing_ok=True)


async def download_images(ctx):
    """Download Danbooru posts matching a tag query into DOWNLOADS and record them.

    Command line: ``download_images [TAG_QUERY [LIMIT]]``. The query defaults
    to ``""`` and the limit to 10. Skipped posts:
        - posts already in the database,
        - posts without an ``md5`` or ``file_url``,
        - files whose extension is not in ALLOWED_EXTENSIONS.
    """
    tagquery = sys.argv[2] if len(sys.argv) > 2 else ""
    limit = int(sys.argv[3]) if len(sys.argv) > 3 else 10
    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        if "md5" not in post:
            continue
        md5 = post["md5"]
        log.info("processing post %r", md5)
        if await fetch_post(ctx, md5):
            log.info("already exists %r", md5)
            continue
        post_file_url = post.get("file_url")
        if not post_file_url:
            log.info("ignoring %r, no file_url", md5)
            continue
        post_url_path = Path(urlparse(post_file_url).path)
        if not _has_allowed_extension(post_url_path):
            log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
            continue
        post_filepath = DOWNLOADS / post_url_path.name
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            await _download_file(ctx.session, post_file_url, post_filepath)
        # Record the post only once its file is fully on disk.
        await insert_post(ctx, post)
async def fetch_post(ctx, md5) -> Optional[dict]:
rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
if not rows:
return None
assert len(rows) == 1
2023-06-10 03:50:38 +00:00
post = json.loads(rows[0][0])
post_rating = post["rating"]
match post_rating:
case "g":
rating_tag = "general"
case "s":
rating_tag = "sensitive"
case "q":
rating_tag = "questionable"
case "e":
rating_tag = "explicit"
case _:
raise AssertionError("invalid post rating {post_rating!r}")
post["tag_string"] = post["tag_string"] + " " + rating_tag
return post
2023-06-10 02:19:15 +00:00
async def insert_post(ctx, post):
    """Store Danbooru *post* in the posts table, keyed by its md5, and commit.

    ``filepath`` holds the last path component of the post's ``file_url``.
    The whole post is kept as JSON in ``data``.
    """
    row = (
        post["md5"],
        Path(urlparse(post["file_url"]).path).name,
        json.dumps(post),
    )
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)", row
    )
    await ctx.db.commit()


async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str
):
    """Store *tag_string* as *interrogator*'s output for post *md5* and commit."""
    params = (md5, interrogator.model_id, tag_string)
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)",
        params,
    )
    await ctx.db.commit()


async def fight(ctx):
    """Run each configured interrogator over every stored post it has not tagged yet.

    Each result is committed as it arrives (via insert_interrogated_result),
    so an interrupted run resumes where it stopped.

    Raises:
        FileNotFoundError: a post in the database has no matching file in DOWNLOADS.
    """
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}
    for interrogator in interrogators:
        log.info("processing for %r", interrogator)
        # Work out which posts this model has not interrogated yet.
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = {row[0] for row in interrogated_rows}
        missing_hashes = all_hashes - interrogated_hashes
        log.info("missing %d hashes", len(missing_hashes))
        # start=1 so progress reads 1/N ... N/N.
        for index, missing_hash in enumerate(missing_hashes, start=1):
            log.info(
                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
            )
            # This glob assumes the saved file name starts with the post's md5.
            # Default to None: a bare StopIteration escaping a coroutine is
            # turned into an opaque RuntimeError (PEP 479).
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"), None)
            if post_filepath is None:
                raise FileNotFoundError(
                    f"no downloaded file for post {missing_hash!r} in {DOWNLOADS}"
                )
            tag_string = await interrogator.interrogate(ctx.session, post_filepath)
            log.info("got %r", tag_string)
            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string
            )


def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score *interrogator_tags* against the reference *danbooru_tags*.

    Returns ``(ratio, extra)``:
        - ``extra`` is the set of interrogator tags absent from *danbooru_tags*.
        - ``ratio`` is ``(matched - len(extra)) / len(danbooru_tags)`` as a
          Decimal rounded to 10 places, where ``matched`` counts the tags the
          two sets share.

    An empty *danbooru_tags* raises ``decimal.DecimalException`` (division by zero).
    """
    extra_tags = interrogator_tags - danbooru_tags
    matched_count = len(interrogator_tags & danbooru_tags)
    numerator = decimal.Decimal(matched_count - len(extra_tags))
    ratio = numerator / decimal.Decimal(len(danbooru_tags))
    return round(ratio, 10), extra_tags


def _print_extremes(post_scores, debug, count=4):
    """Print the *count* lowest and highest per-post scores as ``[low,...,high,]``.

    *post_scores* maps md5 -> {"score": ..., "incorrect_tags": ...}. With
    *debug*, each low-scoring post goes on its own line with its md5 and
    incorrect tags.
    """
    ranked = sorted(post_scores, key=lambda md5_hash: post_scores[md5_hash]["score"])
    print("[", end="")
    for bad_md5_hash in ranked[:count]:
        data = post_scores[bad_md5_hash]
        if debug:
            # Print this post's own hash (the original printed a stale outer-loop variable).
            print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
        else:
            print(data["score"], end=",")
    print("...", end="")
    # Highest scores first.
    for good_md5_hash in ranked[::-1][:count]:
        print(post_scores[good_md5_hash]["score"], end=",")
    print("]")


async def scores(ctx):
    """Score every interrogator against Danbooru's tags and print a ranking.

    Each post's Danbooru tags come from fetch_post, and each model's tags from
    Interrogator.fetch_tags; score() compares them. A model's result is its
    mean per-post score, rounded to 10 places, and models print highest first.

    Environment:
        SHOWOFF=1: also print each model's extreme per-post scores and its
            five most common incorrect tags.
        DEBUG=1: with SHOWOFF, print hash, score and incorrect tags for each
            low-scoring post.
    """
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}
    model_scores = defaultdict(dict)
    incorrect_tags_counters = defaultdict(Counter)
    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())
        for interrogator in interrogators:
            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
            incorrect_tags_counters[interrogator.model_id].update(incorrect_tags)
            model_scores[interrogator.model_id][md5_hash] = {
                "score": tagging_score,
                "incorrect_tags": incorrect_tags,
            }
    normalized_scores = {
        model_id: round(
            sum(d["score"] for d in post_scores.values()) / len(all_hashes), 10
        )
        for model_id, post_scores in model_scores.items()
    }
    # Read the environment once, not once per model.
    showoff = os.getenv("SHOWOFF", "0") == "1"
    debug = os.getenv("DEBUG", "0") == "1"
    print("scores are [worst...best]")
    for model in sorted(normalized_scores, key=normalized_scores.get, reverse=True):
        print(model, normalized_scores[model])
        if showoff:
            _print_extremes(model_scores[model], debug)
            print("most incorrect tags", incorrect_tags_counters[model].most_common(5))


async def realmain(ctx):
    """Create the database schema if needed, then run the mode named by ``sys.argv[1]``.

    Modes: ``download_images``, ``fight`` and ``scores``.

    Raises:
        Exception: no mode was given on the command line.
        AssertionError: the given mode is not one of the above.
    """
    await ctx.db.executescript(
        """
        create table if not exists posts (
            md5 text primary key,
            filepath text,
            data text
        );
        create table if not exists interrogated_posts (
            md5 text,
            model_name text not null,
            output_tag_string text not null,
            primary key (md5, model_name)
        );
        """
    )
    modes = {
        "download_images": download_images,
        "fight": fight,
        "scores": scores,
    }
    valid_modes = ", ".join(modes)
    try:
        mode = sys.argv[1]
    except IndexError:
        # `from None`: the IndexError context only adds noise to the traceback.
        raise Exception(f"must have mode (one of: {valid_modes})") from None
    try:
        handler = modes[mode]
    except KeyError:
        raise AssertionError(
            f"invalid mode {mode} (expected one of: {valid_modes})"
        ) from None
    await handler(ctx)


async def main():
    """Open ``./data.db`` and an HTTP session, load ``config.json`` and run the chosen mode.

    ``config.json`` is read from the working directory, and its keys are passed
    straight to Config.
    """
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db, aiohttp.ClientSession() as session:
        config = Config(**json.loads(config_path.read_text()))
        await realmain(Context(db, session, config))


if __name__ == "__main__":
    # Surface the INFO-level progress messages every mode logs, then
    # drive the async entry point to completion.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())