diff --git a/.gitignore b/.gitignore index f7ab5d7..5d381cc 100644 --- a/.gitignore +++ b/.gitignore @@ -160,6 +160,3 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ -config.json -data.db -posts/ diff --git a/README.md b/README.md index fe130f0..c0ab80d 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,3 @@ # tagger-showdown -attempting to create a more readable evaluation to anime tagger ai systems - -idea: take some recent images from danbooru, also include your own - -then run x tagger systems against each other - -score formula: - -(len(tags in ground_truth) - len(tags not in ground_truth)) / len(ground_truth) - -then average for all posts - -```sh -python3 -m venv env -env/bin/pip install -Ur ./requirements.txt -env/bin/python3 ./main.py -``` \ No newline at end of file +attempting to create a more readable evaluation to anime tagger ai systems \ No newline at end of file diff --git a/config.example.json b/config.example.json deleted file mode 100644 index add9e23..0000000 --- a/config.example.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "sd_webui_address": "http://100.101.194.71:7860/", - "dd_address": "http://localhost:4443", - "dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48" -} diff --git a/main.py b/main.py deleted file mode 100644 index 922421c..0000000 --- a/main.py +++ /dev/null @@ -1,279 +0,0 @@ -import json -import aiohttp -import aiofiles -import sys -import asyncio -import aiosqlite -import base64 -import logging -from collections import defaultdict -from pathlib import Path -from urllib.parse import urlparse -from typing import Any, Optional, List -from aiolimiter import AsyncLimiter -from dataclasses import dataclass, field - -log = logging.getLogger(__name__) - - -DEFAULTS = [ - "wd14-vit-v2-git", - "wd14-vit", - "wd14-swinv2-v2-git", - "wd14-convnextv2-v2-git", - "wd14-convnext-v2-git", - "wd14-convnext", -] - - -@dataclass -class Interrogator: - model_id: str - address: str - - -class DDInterrogator(Interrogator): - async def interrogate(self, session, path): - async with session.post( - f"{self.address}/", - params={ - "threshold": "0.7", - }, - headers={"Authorization": "Bearer sex"}, - data={"file": path.open("rb")}, - ) as resp: - assert resp.status == 200 - tags = await resp.json() - upstream_tags = [tag.replace(" ", "_") for tag in tags] - return " ".join(upstream_tags) - - -class SDInterrogator(Interrogator): - async def interrogate(self, session, path): - async with aiofiles.open(path, "rb") as fd: - as_base64 = base64.b64encode(await fd.read()).decode("utf-8") - - async with session.post( - f"{self.address}/tagger/v1/interrogate", - json={ - "model": self.model_id, - "threshold": 0.7, - "image": as_base64, - }, - ) as resp: - log.info("got %d", resp.status) - assert resp.status == 200 - data = await resp.json() - tags_with_scores = data["caption"] - tags = list(tags_with_scores.keys()) - - upstream_tags = [tag.replace(" ", "_") for tag in tags] - return " ".join(upstream_tags) - - -@dataclass -class Config: - sd_webui_address: str - dd_address: str - dd_model_name: str - sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS)) - - @property - def all_available_models(self) -> List[Any]: - return [ - SDInterrogator(sd_interrogator, self.sd_webui_address) - for sd_interrogator in self.sd_webui_models - ] + [DDInterrogator(self.dd_model_name, self.dd_address)] - - -@dataclass -class Context: - db: Any - session: aiohttp.ClientSession - config: Config - - -@dataclass -class Booru: - session: aiohttp.ClientSession - limiter: AsyncLimiter - tag_limiter: AsyncLimiter - file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock)) - tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock)) - - fetch_type = "hash" - - @property - def hash_style(self): - return HashStyle.md5 - - -class Danbooru(Booru): - title = "Danbooru" - base_url = "https://danbooru.donmai.us" - - async def posts(self, tag_query: str, limit): - log.info("%s: submit %r", self.title, tag_query) - async with self.limiter: - log.info("%s: submit upstream %r", self.title, tag_query) - async with self.session.get( - f"{self.base_url}/posts.json", - params={"tags": tag_query, "limit": limit}, - ) as resp: - assert resp.status == 200 - rjson = await resp.json() - return rjson - - -DOWNLOADS = Path.cwd() / "posts" -DOWNLOADS.mkdir(exist_ok=True) - - -async def download_images(ctx): - try: - tagquery = sys.argv[2] - except IndexError: - tagquery = "" - - try: - limit = int(sys.argv[3]) - except IndexError: - limit = 10 - - danbooru = Danbooru( - ctx.session, - AsyncLimiter(1, 10), - AsyncLimiter(1, 3), - ) - - posts = await danbooru.posts(tagquery, limit) - for post in posts: - if "md5" not in post: - continue - log.info("processing post %r", post) - existing_post = await fetch_post(ctx, post["md5"]) - if existing_post: - continue - - # download the post - post_file_url = post["file_url"] - post_filename = Path(urlparse(post_file_url).path).name - post_filepath = DOWNLOADS / post_filename - if not post_filepath.exists(): - log.info("downloading %r to %r", post_file_url, post_filepath) - async with ctx.session.get(post_file_url) as resp: - assert resp.status == 200 - with post_filepath.open("wb") as fd: - async for chunk in resp.content.iter_chunked(1024): - fd.write(chunk) - - # when it's done successfully, insert it - await insert_post(ctx, post) - - -async def fetch_post(ctx, md5) -> Optional[dict]: - rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,)) - if not rows: - return None - assert len(rows) == 1 - return json.loads(rows[0][0]) - - -async def insert_post(ctx, post): - await ctx.db.execute_insert( - "insert into posts (md5, filepath ,data) values (?, ?,?)", - ( - post["md5"], - Path(urlparse(post["file_url"]).path).name, - json.dumps(post), - ), - ) - await ctx.db.commit() - - -async def insert_interrogated_result( - ctx, interrogator: Interrogator, md5: str, tag_string: str -): - await ctx.db.execute_insert( - "insert into interrogated_posts (md5, model_name, output_tag_string) values (?,?,?)", - (md5, interrogator.model_id, tag_string), - ) - await ctx.db.commit() - - -async def fight(ctx): - interrogators = ctx.config.all_available_models - - all_rows = await ctx.db.execute_fetchall("select md5 from posts") - all_hashes = set(r[0] for r in all_rows) - - for interrogator in interrogators: - log.info("processing for %r", interrogator) - # calculate set of images we didn't interrogate yet - interrogated_rows = await ctx.db.execute_fetchall( - "select md5 from interrogated_posts where model_name = ?", - (interrogator.model_id,), - ) - interrogated_hashes = set(row[0] for row in interrogated_rows) - missing_hashes = all_hashes - interrogated_hashes - - log.info("missing %d hashes", len(missing_hashes)) - - for missing_hash in missing_hashes: - log.info("interrogating %r", missing_hash) - post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*")) - tag_string = await interrogator.interrogate(ctx.session, post_filepath) - log.info("got %r", tag_string) - await insert_interrogated_result( - ctx, interrogator, missing_hash, tag_string - ) - - -async def realmain(ctx): - await ctx.db.executescript( - """ - create table if not exists posts ( - md5 text primary key, - filepath text, - data text - ); - create table if not exists interrogated_posts ( - md5 text, - model_name text not null, - output_tag_string text not null, - primary key (md5, model_name) - ); - """ - ) - - try: - mode = sys.argv[1] - except IndexError: - raise Exception("must have mode") - - if mode == "download_images": - await download_images(ctx) - elif mode == "fight": - await fight(ctx) - elif mode == "scores": - await scores(ctx) - else: - raise AssertionError(f"invalid mode {mode}") - - -async def main(): - CONFIG_PATH = Path.cwd() / "config.json" - log.info("hewo") - async with aiosqlite.connect("./data.db") as db: - async with aiohttp.ClientSession() as session: - with CONFIG_PATH.open() as config_fd: - config_json = json.load(config_fd) - - config = Config(**config_json) - ctx = Context(db, session, config) - await realmain(ctx) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - asyncio.run(main()) diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index db0e4df..0000000 --- a/requirements.txt +++ /dev/null @@ -1,3 +0,0 @@ -aiosqlite==0.19.0 -aiohttp==3.8.4 -aiolimiter==1.1.0 \ No newline at end of file