Compare commits

...

2 Commits

Author SHA1 Message Date
Luna dbf458eed7 add fight mode between tagger models 2023-06-09 23:53:04 -03:00
Luna 6d113e77b8 add download_images function 2023-06-09 23:19:15 -03:00
5 changed files with 307 additions and 1 deletions

3
.gitignore vendored
View File

@ -160,3 +160,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
config.json
data.db
posts/

View File

@ -1,3 +1,19 @@
# tagger-showdown
attempting to create a more readable evaluation of anime tagger AI systems
attempting to create a more readable evaluation of anime tagger AI systems
idea: take some recent images from danbooru, also include your own
then run x tagger systems against each other
score formula:
(len(tags in ground_truth) - len(tags not in ground_truth)) / len(ground_truth)
then average for all posts
```sh
python3 -m venv env
env/bin/pip install -Ur ./requirements.txt
env/bin/python3 ./main.py
```

5
config.example.json Normal file
View File

@ -0,0 +1,5 @@
{
"sd_webui_address": "http://100.101.194.71:7860/",
"dd_address": "http://localhost:4443",
"dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48"
}

279
main.py Normal file
View File

@ -0,0 +1,279 @@
import json
import aiohttp
import aiofiles
import sys
import asyncio
import aiosqlite
import base64
import logging
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
# Module-level logger; configured via basicConfig in the __main__ guard.
log = logging.getLogger(__name__)
# Default tagger model names served by the sd-webui "tagger" extension;
# used when config.json does not override "sd_webui_models" (see Config).
DEFAULTS: List[str] = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]
@dataclass
class Interrogator:
    """A tagging backend: a model identifier plus the HTTP address serving it.

    Subclasses (SDInterrogator, DDInterrogator) implement
    ``interrogate(session, path)`` and return a space-separated tag string.
    """
    # Model name understood by the remote tagging service.
    model_id: str
    # Base URL of the service, e.g. "http://localhost:4443" (see config.example.json).
    address: str
class DDInterrogator(Interrogator):
    """Client for a hydrus-dd (DeepDanbooru) style HTTP tagging service."""

    async def interrogate(self, session, path):
        """Upload the image at *path* and return its tags.

        Returns a single space-separated tag string with spaces inside
        tags normalized to underscores.  Raises AssertionError when the
        service does not answer 200.
        """
        # Open inside a context manager so the handle is always closed
        # (the original leaked the file object returned by path.open).
        with path.open("rb") as fd:
            async with session.post(
                f"{self.address}/",
                params={
                    # Confidence cutoff forwarded to the service.
                    "threshold": "0.7",
                },
                headers={"Authorization": "Bearer sex"},
                data={"file": fd},
            ) as resp:
                assert resp.status == 200
                tags = await resp.json()
        # Service returns tags with spaces; normalize to danbooru underscore style.
        upstream_tags = [tag.replace(" ", "_") for tag in tags]
        return " ".join(upstream_tags)
class SDInterrogator(Interrogator):
    """Client for the sd-webui "tagger" extension HTTP API."""

    async def interrogate(self, session, path):
        """Tag the image at *path*; return one space-separated tag string."""
        async with aiofiles.open(path, "rb") as image_file:
            raw_bytes = await image_file.read()
        encoded_image = base64.b64encode(raw_bytes).decode("utf-8")
        payload = {
            "model": self.model_id,
            "threshold": 0.7,
            "image": encoded_image,
        }
        async with session.post(
            f"{self.address}/tagger/v1/interrogate",
            json=payload,
        ) as resp:
            log.info("got %d", resp.status)
            assert resp.status == 200
            data = await resp.json()
        # "caption" maps tag -> confidence; we only keep the tag names.
        caption = data["caption"]
        return " ".join(tag.replace(" ", "_") for tag in caption.keys())
@dataclass
class Config:
    """Runtime configuration, loaded from config.json in main()."""

    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))

    @property
    def all_available_models(self) -> List[Any]:
        """Build one Interrogator per configured backend, sd-webui models first."""
        models: List[Any] = []
        for model_name in self.sd_webui_models:
            models.append(SDInterrogator(model_name, self.sd_webui_address))
        models.append(DDInterrogator(self.dd_model_name, self.dd_address))
        return models
@dataclass
class Context:
    """Shared per-run resources handed to every command coroutine."""
    # Open aiosqlite connection (created in main()); typed Any to keep the
    # dataclass free of a hard aiosqlite type dependency.
    db: Any
    session: aiohttp.ClientSession
    config: Config
@dataclass
class Booru:
    """Base class for a booru-style image board API client."""
    session: aiohttp.ClientSession
    # Rate limiter applied to post listing requests (see Danbooru.posts).
    limiter: AsyncLimiter
    # Second, separate limiter — presumably for tag queries; not used in the
    # code visible here.
    tag_limiter: AsyncLimiter
    # Per-key asyncio locks; not referenced by the code visible in this file.
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    fetch_type = "hash"
    @property
    def hash_style(self):
        # NOTE(review): HashStyle is neither defined nor imported anywhere in
        # this file — accessing this property would raise NameError.  Looks
        # copied from a project that defines HashStyle; confirm before use.
        return HashStyle.md5
class Danbooru(Booru):
    """Danbooru API client (uses only the /posts.json endpoint)."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        """Fetch up to *limit* posts matching *tag_query*, rate limited.

        Returns the decoded JSON response (a list of post dicts).
        """
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            query_params = {"tags": tag_query, "limit": limit}
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params=query_params,
            ) as resp:
                assert resp.status == 200
                return await resp.json()
# Directory (next to the current working directory) where post images land.
DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)
async def download_images(ctx):
    """Download recent danbooru posts into DOWNLOADS and index them in the db.

    CLI shape: ``main.py download_images [tag_query] [limit]`` with defaults
    of an empty query and 10 posts.  Posts already recorded in the db are
    skipped; already-downloaded files are not re-fetched.
    """
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""
    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10
    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),  # at most one posts.json request per 10 seconds
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        # Restricted/banned posts come back without "md5" and/or "file_url";
        # the original only checked md5 and crashed with KeyError on the
        # missing file_url — skip both cases.
        if "md5" not in post or "file_url" not in post:
            continue
        log.info("processing post %r", post)
        existing_post = await fetch_post(ctx, post["md5"])
        if existing_post:
            continue
        # download the post
        post_file_url = post["file_url"]
        post_filename = Path(urlparse(post_file_url).path).name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            async with ctx.session.get(post_file_url) as resp:
                assert resp.status == 200
                with post_filepath.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)
        # when it's done successfully, insert it
        await insert_post(ctx, post)
async def fetch_post(ctx, md5) -> Optional[dict]:
    """Return the stored danbooru post dict for *md5*, or None when absent."""
    query = "select data from posts where md5 = ?"
    rows = await ctx.db.execute_fetchall(query, (md5,))
    if rows:
        # md5 is the table's primary key, so duplicates indicate corruption.
        assert len(rows) == 1
        (raw_json,) = rows[0]
        return json.loads(raw_json)
    return None
async def insert_post(ctx, post):
    """Persist a raw danbooru post dict into the posts table, keyed by md5.

    The stored filepath is the basename of the post's file_url, matching
    what download_images writes into DOWNLOADS.
    """
    stored_filename = Path(urlparse(post["file_url"]).path).name
    row = (post["md5"], stored_filename, json.dumps(post))
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)",
        row,
    )
    await ctx.db.commit()
async def insert_interrogated_result(
    ctx, interrogator: Interrogator, md5: str, tag_string: str
):
    """Record one model's tag output for the post identified by *md5*."""
    row = (md5, interrogator.model_id, tag_string)
    await ctx.db.execute_insert(
        "insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)",
        row,
    )
    await ctx.db.commit()
async def fight(ctx):
    """Run every configured model over every downloaded post it hasn't tagged.

    For each interrogator, computes the set of post hashes without a row in
    interrogated_posts for that model and tags them.  Posts whose image file
    has gone missing from DOWNLOADS are skipped with a warning (the original
    crashed there: next() on an exhausted glob raises StopIteration, which
    surfaces as RuntimeError inside a coroutine).
    """
    interrogators = ctx.config.all_available_models
    all_rows = await ctx.db.execute_fetchall("select md5 from posts")
    all_hashes = set(r[0] for r in all_rows)
    for interrogator in interrogators:
        log.info("processing for %r", interrogator)
        # calculate set of images we didn't interrogate yet
        interrogated_rows = await ctx.db.execute_fetchall(
            "select md5 from interrogated_posts where model_name = ?",
            (interrogator.model_id,),
        )
        interrogated_hashes = set(row[0] for row in interrogated_rows)
        missing_hashes = all_hashes - interrogated_hashes
        log.info("missing %d hashes", len(missing_hashes))
        for missing_hash in missing_hashes:
            log.info("interrogating %r", missing_hash)
            # Downloaded filenames start with the post md5, so glob by prefix.
            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"), None)
            if post_filepath is None:
                log.warning("no local file for %r, skipping", missing_hash)
                continue
            tag_string = await interrogator.interrogate(ctx.session, post_filepath)
            log.info("got %r", tag_string)
            await insert_interrogated_result(
                ctx, interrogator, missing_hash, tag_string
            )
async def realmain(ctx):
    """Ensure the db schema exists, then run the subcommand named in argv[1].

    Raises Exception when no mode is given and AssertionError on an
    unknown mode.
    """
    await ctx.db.executescript(
        """
    create table if not exists posts (
        md5 text primary key,
        filepath text,
        data text
    );
    create table if not exists interrogated_posts (
        md5 text,
        model_name text not null,
        output_tag_string text not null,
        primary key (md5, model_name)
    );
    """
    )
    if len(sys.argv) < 2:
        raise Exception("must have mode")
    mode = sys.argv[1]
    if mode == "download_images":
        handler = download_images
    elif mode == "fight":
        handler = fight
    elif mode == "scores":
        # NOTE(review): scores() is not visible in this file chunk — presumably
        # defined further down in main.py; confirm it exists.
        handler = scores
    else:
        raise AssertionError(f"invalid mode {mode}")
    await handler(ctx)
async def main():
    """Entry point: open the db and HTTP session, load config, run realmain()."""
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as http_session:
            with config_path.open() as config_file:
                loaded = json.load(config_file)
            config = Config(**loaded)
            await realmain(Context(db, http_session, config))
if __name__ == "__main__":
    # Configure root logging before starting the event loop so the
    # log.info calls throughout the module are visible.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())

3
requirements.txt Normal file
View File

@ -0,0 +1,3 @@
aiosqlite==0.19.0
aiohttp==3.8.4
aiolimiter==1.1.0
aiofiles