# tagger-showdown/main.py

# 556 lines
# 17 KiB
# Python

import time
import os
import json
import aiohttp
import pprint
import decimal
import aiofiles
import sys
import asyncio
import aiosqlite
import base64
import logging
import pandas as pd
import plotly.express as px
import plotly.graph_objs as go
import plotly.io as pio
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set, Tuple
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
# Module-level logger shared by every function in this file.
log = logging.getLogger(__name__)

# Model names used for ``Config.sd_webui_models`` when the configuration does
# not provide its own list.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]
@dataclass
class InterrogatorPost:
    """Stored output of one interrogator for a single post."""

    # Tags the interrogator produced, one list entry per tag.
    tags: List[str]
    # Value of the ``time_taken`` column recorded for that run.
    time_taken: float
@dataclass
class Interrogator:
    """Base class for taggers whose results live in ``interrogated_posts``.

    Subclasses supply ``interrogate(ctx, path)`` and may override
    ``_process`` to normalise the tags they stored.
    """

    # Value written to / matched against the ``model_name`` column.
    model_id: str
    # Base URL of the tagger service; ``None`` when no service is involved.
    address: Optional[str]

    def _process(self, lst):
        """Return the stored tag list unchanged (hook for subclasses)."""
        return lst

    async def fetch(self, ctx, md5_hash):
        """Load this model's stored result for ``md5_hash``.

        Args:
            ctx: object exposing ``db.execute_fetchall``.
            md5_hash: md5 of the post to look up.

        Returns:
            An ``InterrogatorPost`` with the processed tags and time taken.

        Raises:
            AssertionError: when there is not exactly one stored row for
                this model and post (typically 'fight' mode was not run).
        """
        rows = await ctx.db.execute_fetchall(
            "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
            (md5_hash, self.model_id),
        )
        # Raised explicitly (not via ``assert``) so the check survives
        # ``python -O`` and the failure names the post and model.
        if len(rows) != 1:
            raise AssertionError(
                f"expected exactly one interrogated post for md5 {md5_hash!r} "
                f"and model {self.model_id!r}, found {len(rows)}: "
                "run 'fight' mode, there's missing posts"
            )
        tag_string, time_taken = rows[0][0], rows[0][1]
        return InterrogatorPost(self._process(tag_string.split()), time_taken)
class DDInterrogator(Interrogator):
    """Interrogator that uploads the image file to the service at ``address``."""

    def _process(self, lst):
        """Normalise stored tags so ratings use the same names as the posts.

        ``rating:<name>`` tags lose their prefix, and a resulting ``safe``
        tag is renamed to ``general``.
        """
        normalised = []
        for tag in lst:
            if tag.startswith("rating:"):
                original_danbooru_tag = tag.split(":")[1]
            else:
                original_danbooru_tag = tag
            if original_danbooru_tag == "safe":
                original_danbooru_tag = "general"
            normalised.append(original_danbooru_tag)
        return normalised

    async def interrogate(self, ctx, path):
        """Upload the file at ``path`` and return its tags as one string.

        Args:
            ctx: object exposing the HTTP client as ``session``.
            path: ``pathlib.Path`` of the image to upload.

        Returns:
            Space-separated tags, with spaces inside a tag replaced by ``_``.

        Raises:
            AssertionError: when the service does not answer with HTTP 200.
        """
        # The context manager guarantees the handle is closed on every path,
        # including a request that fails before the body is sent.
        with path.open("rb") as image_fd:
            async with ctx.session.post(
                f"{self.address}/",
                params={
                    "threshold": "0.7",
                },
                headers={"Authorization": "Bearer sex"},
                data={"file": image_fd},
            ) as resp:
                # Explicit raise: a bare ``assert`` disappears under ``-O``.
                if resp.status != 200:
                    raise AssertionError(
                        f"unexpected status {resp.status} from {self.address}"
                    )
                tags = await resp.json()
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)
class SDInterrogator(Interrogator):
    """Interrogator that calls ``{address}/tagger/v1/interrogate``."""

    async def interrogate(self, ctx, path):
        """Send the image at ``path`` as base64 and return its tags.

        Args:
            ctx: object exposing the HTTP client as ``session``.
            path: location of the image file to read.

        Returns:
            Space-separated tags, with spaces inside a tag replaced by ``_``.

        Raises:
            AssertionError: when the service does not answer with HTTP 200.
        """
        async with aiofiles.open(path, "rb") as fd:
            as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
        async with ctx.session.post(
            f"{self.address}/tagger/v1/interrogate",
            json={
                "model": self.model_id,
                "threshold": 0.7,
                "image": as_base64,
            },
        ) as resp:
            log.info("got %d", resp.status)
            # Explicit raise: a bare ``assert`` disappears under ``-O``.
            if resp.status != 200:
                raise AssertionError(
                    f"unexpected status {resp.status} from {self.address}"
                )
            data = await resp.json()
        # Only the keys of "caption" are used as tag names; the values
        # (presumably confidence scores) are ignored.
        tags_with_scores = data["caption"]
        return " ".join(tag.replace(" ", "_") for tag in tags_with_scores)
class ControlInterrogator(Interrogator):
    """Baseline interrogator that returns the tags stored with the post."""

    async def interrogate(self, ctx, path):
        """Return the stored ``tag_string`` of the post named by ``path``.

        The file stem of ``path`` is used as the post's md5.
        """
        stored_post = await fetch_post(ctx, Path(path).stem)
        return stored_post["tag_string"]
@dataclass
class Config:
    """Settings built from the keys of the JSON configuration."""

    # Base URL handed to every SDInterrogator.
    sd_webui_address: str
    # Base URL handed to the DDInterrogator.
    dd_address: str
    # Model id under which the DDInterrogator stores its results.
    dd_model_name: str
    # Model ids to run through SDInterrogator; defaults to a copy of DEFAULTS.
    sd_webui_models: List[str] = field(default_factory=lambda: [*DEFAULTS])

    @property
    def all_available_models(self) -> List[Any]:
        """Build a fresh list with every interrogator to run.

        Order: one SDInterrogator per entry of ``sd_webui_models``, then the
        DDInterrogator, then the "control" interrogator.
        """
        models: List[Any] = [
            SDInterrogator(model_name, self.sd_webui_address)
            for model_name in self.sd_webui_models
        ]
        models.append(DDInterrogator(self.dd_model_name, self.dd_address))
        models.append(ControlInterrogator("control", None))
        return models
@dataclass
class Context:
    """Bundle of shared resources handed to every mode of the script."""

    # Database handle; callers use execute_fetchall / execute_insert / commit.
    db: Any
    # HTTP client used for all outgoing requests.
    session: aiohttp.ClientSession
    # Parsed configuration.
    config: Config
def _new_lock_table():
    """Return a mapping that creates an ``asyncio.Lock`` per missing key."""
    return defaultdict(asyncio.Lock)


@dataclass
class Booru:
    """Shared state for a booru API client."""

    # HTTP client used for requests.
    session: aiohttp.ClientSession
    # Rate limiter entered before post queries.
    limiter: AsyncLimiter
    # Second rate limiter; not used by any code visible in this file.
    tag_limiter: AsyncLimiter
    # Lock tables, each with its own defaultdict per instance.
    file_locks: dict = field(default_factory=_new_lock_table)
    tag_locks: dict = field(default_factory=_new_lock_table)

    # Plain class attribute (no annotation, so not a dataclass field).
    fetch_type = "hash"

    @property
    def hash_style(self):
        """Return ``HashStyle.md5``."""
        # NOTE(review): HashStyle is not defined or imported in the visible
        # part of this file -- confirm it exists before relying on this.
        return HashStyle.md5
class Danbooru(Booru):
    """Client for the post listing endpoint at ``base_url``."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit, page: int):
        """Fetch one page of posts matching ``tag_query``.

        Args:
            tag_query: value sent as the ``tags`` query parameter.
            limit: value sent as the ``limit`` query parameter.
            page: value sent as the ``page`` query parameter.

        Returns:
            The decoded JSON body of ``{base_url}/posts.json``.

        Raises:
            AssertionError: when the server does not answer with HTTP 200.
        """
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit, "page": page},
            ) as resp:
                # Explicit raise: a bare ``assert`` disappears under ``-O``.
                if resp.status != 200:
                    raise AssertionError(
                        f"{self.title}: unexpected status {resp.status} "
                        f"for query {tag_query!r}"
                    )
                return await resp.json()
# Directory for downloaded post images, relative to the current working
# directory; created at import time when it does not exist yet.
DOWNLOADS = Path.cwd().joinpath("posts")
DOWNLOADS.mkdir(exist_ok=True)

# File extensions (without the dot) that download_images accepts.
ALLOWED_EXTENSIONS = ("jpg", "jpeg", "png")
def _argv_or_default(position, default, convert=str):
    """Return ``convert(sys.argv[position])``, or ``default`` when absent."""
    try:
        raw_value = sys.argv[position]
    except IndexError:
        return default
    return convert(raw_value)


def _has_allowed_extension(url_path):
    """Tell whether the suffix of ``url_path`` is one of ALLOWED_EXTENSIONS."""
    # Exact, case-insensitive comparison: a substring test would also accept
    # suffixes such as ".apng" because they merely contain "png".
    return url_path.suffix.lstrip(".").lower() in ALLOWED_EXTENSIONS


async def _download_to(ctx, url, destination):
    """Stream ``url`` into ``destination`` without leaving a partial file."""
    # Write to a scratch name first, so an interrupted download is never
    # mistaken for a finished one by the ``exists()`` check on a later run.
    # The leading dot keeps it out of "<md5>*" glob lookups.
    scratch_path = destination.with_name(f".part-{destination.name}")
    try:
        async with ctx.session.get(url) as resp:
            # Explicit raise: a bare ``assert`` disappears under ``-O``.
            if resp.status != 200:
                raise AssertionError(
                    f"unexpected status {resp.status} downloading {url!r}"
                )
            with scratch_path.open("wb") as fd:
                async for chunk in resp.content.iter_chunked(64 * 1024):
                    fd.write(chunk)
        scratch_path.replace(destination)
    except BaseException:
        scratch_path.unlink(missing_ok=True)
        raise


async def download_images(ctx):
    """Download one page of posts and record them in the ``posts`` table.

    Command-line arguments (all optional):
        ``sys.argv[2]``: tag query, default ``""``.
        ``sys.argv[3]``: post limit, default ``30``.
        ``sys.argv[4]``: page number, default ``150``.

    Posts without an ``md5``, posts already stored, and posts whose file
    extension is not in ALLOWED_EXTENSIONS are skipped.

    Raises:
        ValueError: when the limit or page argument is not an integer.
        AssertionError: when a request does not answer with HTTP 200.
    """
    tagquery = _argv_or_default(2, "")
    limit = _argv_or_default(3, 30, int)
    pageskip = _argv_or_default(4, 150, int)

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit, pageskip)

    for post in posts:
        if "md5" not in post:
            continue
        md5 = post["md5"]
        log.info("processing post %r", md5)
        existing_post = await fetch_post(ctx, md5)
        if existing_post:
            log.info("already exists %r", md5)
            continue

        post_file_url = post["file_url"]
        post_url_path = Path(urlparse(post_file_url).path)
        if not _has_allowed_extension(post_url_path):
            log.info("ignoring %r, invalid extension (%r)", md5, post_url_path.suffix)
            continue

        post_filepath = DOWNLOADS / post_url_path.name
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            await _download_to(ctx, post_file_url, post_filepath)

        # NOTE(review): the original indentation was lost; the insert is
        # assumed to run for every accepted post, including those whose file
        # was already on disk -- confirm.
        await insert_post(ctx, post)
async def fetch_post(ctx, md5) -> Optional[dict]:
rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
if not rows:
return None
assert len(rows) == 1
post = json.loads(rows[0][0])
post_rating = post["rating"]
match post_rating:
case "g":
rating_tag = "general"
case "s":
rating_tag = "sensitive"
case "q":
rating_tag = "questionable"
case "e":
rating_tag = "explicit"
case _:
raise AssertionError("invalid post rating {post_rating!r}")
post["tag_string"] = post["tag_string"] + " " + rating_tag
return post
async def insert_post(ctx, post):
    """Store ``post`` in the ``posts`` table and commit.

    The row holds the post's md5, the file name taken from the path of its
    ``file_url``, and the whole post serialised as JSON.
    """
    md5 = post["md5"]
    stored_filename = Path(urlparse(post["file_url"]).path).name
    serialised_post = json.dumps(post)
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)",
        (md5, stored_filename, serialised_post),
    )
    await ctx.db.commit()
async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str, time_taken: float
):
    """Store one interrogation result in ``interrogated_posts`` and commit.

    Args:
        ctx: object exposing ``db.execute_insert`` and ``db.commit``.
        interrogator: its ``model_id`` is stored as ``model_name``.
        md5: md5 of the interrogated post.
        tag_string: tags the interrogator returned, as a single string.
        time_taken: duration recorded for the interrogation.
    """
    row = (md5, interrogator.model_id, tag_string, time_taken)
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string, time_taken) values (?,?,?,?)",
        row,
    )
    await ctx.db.commit()
async def fight(ctx):
    """Run every configured interrogator over the posts it has not seen yet.

    For each interrogator, posts present in ``posts`` but absent from
    ``interrogated_posts`` for that model are interrogated, timed with
    ``time.monotonic`` and stored.

    Raises:
        FileNotFoundError: when no file in DOWNLOADS starts with the md5 of
            a post that still has to be interrogated.
    """
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}

    for interrogator in interrogators:
        log.info("processing for %r", interrogator)
        # calculate set of images we didn't interrogate yet
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = {row[0] for row in interrogated_rows}
        missing_hashes = all_hashes - interrogated_hashes
        log.info("missing %d hashes", len(missing_hashes))

        # Count from 1 so the progress log ends at N/N rather than N-1/N.
        for index, missing_hash in enumerate(missing_hashes, start=1):
            log.info(
                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
            )
            # A bare next() would raise StopIteration on an empty glob, which
            # cannot propagate cleanly out of a coroutine; fail explicitly.
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"), None)
            if post_filepath is None:
                raise FileNotFoundError(
                    f"no downloaded file for post {missing_hash!r} in {DOWNLOADS}"
                )

            start_ts = time.monotonic()
            tag_string = await interrogator.interrogate(ctx, post_filepath)
            end_ts = time.monotonic()
            time_taken = round(end_ts - start_ts, 10)
            log.info("took %.5fsec, got %r", time_taken, tag_string)

            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string, time_taken
            )
def score(
    danbooru_tags: Set[str], interrogator_tags: Set[str]
) -> Tuple[decimal.Decimal, Set[str]]:
    """Score an interrogator's tags against the post's own tags.

    The score is (matching tags - extra tags) / number of post tags, rounded
    to 10 decimal places. Extra tags are those the interrogator produced
    that the post does not have.

    Returns:
        ``(score, extra_tags)``.
    """
    extra_tags = interrogator_tags - danbooru_tags
    matching_tags = danbooru_tags & interrogator_tags
    net_hits = decimal.Decimal(len(matching_tags) - len(extra_tags))
    ratio = net_hits / decimal.Decimal(len(danbooru_tags))
    return round(ratio, 10), extra_tags
def _models_best_first(normalized_scores):
    """Return model ids ordered from highest to lowest normalized score."""
    return sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    )


def _print_extreme_scores(post_scores):
    """Print the four lowest and four highest per-post scores of one model."""
    print("[", end="")
    for bad_md5_hash in sorted(
        post_scores.keys(),
        key=lambda md5_hash: post_scores[md5_hash]["score"],
    )[:4]:
        data = post_scores[bad_md5_hash]
        if os.getenv("DEBUG", "0") == "1":
            # Print the hash this row belongs to, not the leftover variable
            # of the scoring loop.
            print(bad_md5_hash, data["score"], " ".join(data["incorrect_tags"]))
        else:
            print(data["score"], end=",")
    print("...", end="")
    for good_md5_hash in sorted(
        post_scores.keys(),
        key=lambda md5_hash: post_scores[md5_hash]["score"],
        reverse=True,
    )[:4]:
        data = post_scores[good_md5_hash]
        print(data["score"], end=",")
    print("]")


async def scores(ctx):
    """Score every interrogator against the stored posts, print and plot.

    For each post, each interrogator's stored tags are scored with
    ``score``; per-model averages, average runtimes and the most common
    incorrect tags are printed, and three images are written to ``./plots``.

    Environment:
        SHOWOFF=1: also print the lowest/highest per-post scores per model.
        DEBUG=1: with SHOWOFF, print hash and incorrect tags for the lowest.
    """
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = {row[0] for row in all_rows}

    # model id -> md5 -> {"score", "incorrect_tags"}
    model_scores = defaultdict(dict)
    runtimes = defaultdict(list)
    incorrect_tags_counters = defaultdict(Counter)

    for md5_hash in all_hashes:
        log.info("processing for %r", md5_hash)
        post = await fetch_post(ctx, md5_hash)
        danbooru_tags = set(post["tag_string"].split())
        for interrogator in interrogators:
            post_data = await interrogator.fetch(ctx, md5_hash)
            runtimes[interrogator.model_id].append(post_data.time_taken)
            interrogator_tags = set(post_data.tags)
            tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
            incorrect_tags_counters[interrogator.model_id].update(incorrect_tags)
            model_scores[interrogator.model_id][md5_hash] = {
                "score": tagging_score,
                "incorrect_tags": incorrect_tags,
            }

    summed_scores = {
        model_id: sum(d["score"] for d in post_scores.values())
        for model_id, post_scores in model_scores.items()
    }
    normalized_scores = {
        model: round(summed_scores[model] / len(all_hashes), 10)
        for model in summed_scores
    }

    print("scores are [worst...best]")
    for model in _models_best_first(normalized_scores):
        average_runtime = sum(runtimes[model]) / len(runtimes[model])
        print(model, normalized_scores[model], "runtime", average_runtime, "sec")
        if os.getenv("SHOWOFF", "0") == "1":
            _print_extreme_scores(model_scores[model])
        # NOTE(review): the original indentation was lost; this line is
        # assumed to print for every model regardless of SHOWOFF -- confirm.
        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))

    PLOTS = Path.cwd() / "plots"
    PLOTS.mkdir(exist_ok=True)

    log.info("plotting score histogram...")
    data_for_df = {"scores": [], "model": []}
    for model in _models_best_first(normalized_scores):
        for post_score in (d["score"] for d in model_scores[model].values()):
            data_for_df["scores"].append(post_score)
            data_for_df["model"].append(model)
    df = pd.DataFrame(data_for_df)
    fig = px.histogram(
        df,
        x="scores",
        color="model",
        histfunc="count",
        marginal="rug",
        histnorm="probability",
    )
    pio.write_image(fig, PLOTS / "score_histogram.png", width=1024, height=800)

    log.info("plotting positive histogram...")
    plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
    log.info("plotting error rates...")
    plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
def plot2(output_path, normalized_scores, model_scores):
    """Write a histogram of the non-negative per-post scores per model.

    Args:
        output_path: destination passed to ``pio.write_image``.
        normalized_scores: model id -> average score, used only for ordering
            (highest first).
        model_scores: model id -> post id -> dict with a ``"score"`` entry.
    """
    ranked_models = sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    )
    # Negative scores are left out; everything else keeps its order.
    kept_rows = [
        (entry["score"], model)
        for model in ranked_models
        for entry in model_scores[model].values()
        if entry["score"] >= 0
    ]
    frame = pd.DataFrame(
        {
            "scores": [post_score for post_score, _ in kept_rows],
            "model": [model for _, model in kept_rows],
        }
    )
    figure = px.histogram(
        frame,
        x="scores",
        color="model",
        histfunc="count",
        marginal="rug",
        histnorm="probability",
    )
    pio.write_image(figure, output_path, width=1024, height=800)
def plot3(output_path, normalized_scores, model_scores):
    """Write a bar chart of incorrect tags and incorrect ratings per model.

    Args:
        output_path: destination passed to ``pio.write_image``.
        normalized_scores: model id -> average score, used only for ordering
            (highest first).
        model_scores: model id -> post id -> dict with an
            ``"incorrect_tags"`` collection.
    """
    rating_names = ["general", "sensitive", "questionable", "explicit"]
    table = {"model": [], "errors": [], "rating_errors": []}
    ranked_models = sorted(
        normalized_scores.keys(),
        key=lambda model: normalized_scores[model],
        reverse=True,
    )
    for model in ranked_models:
        incorrect_per_post = [
            entry["incorrect_tags"] for entry in model_scores[model].values()
        ]
        table["errors"].append(sum(len(wrong) for wrong in incorrect_per_post))
        # A rating error is a rating name that shows up among a post's
        # incorrect tags.
        table["rating_errors"].append(
            sum(
                1
                for wrong in incorrect_per_post
                for rating in rating_names
                if rating in wrong
            )
        )
        table["model"].append(model)
    frame = pd.DataFrame(table)
    bars = [
        go.Bar(name="incorrect tags", x=frame.model, y=frame.errors),
        go.Bar(name="incorrect ratings", x=frame.model, y=frame.rating_errors),
    ]
    figure = go.Figure(data=bars)
    pio.write_image(figure, output_path, width=1024, height=800)
async def realmain(ctx):
await ctx.db.executescript(
"""
create table if not exists posts (
md5 text primary key,
filepath text,
data text
);
create table if not exists interrogated_posts (
md5 text,
model_name text not null,
output_tag_string text not null,
time_taken real not null,
primary key (md5, model_name)
);
"""
)
try:
mode = sys.argv[1]
except IndexError:
raise Exception("must have mode")
if mode == "download_images":
await download_images(ctx)
elif mode == "fight":
await fight(ctx)
elif mode == "scores":
await scores(ctx)
else:
raise AssertionError(f"invalid mode {mode}")
async def main():
    """Open the database and HTTP session, load the config, run ``realmain``.

    Reads ``config.json`` from the current working directory and uses
    ``./data.db`` as the database file.
    """
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db, aiohttp.ClientSession() as session:
        # The keys of the JSON object become Config's constructor arguments.
        config = Config(**json.loads(config_path.read_text()))
        await realmain(Context(db, session, config))
# Script entry point: enable INFO logging, then run the async main routine.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())