From 6d113e77b83b90dcb8a8e5668793983472cff100 Mon Sep 17 00:00:00 2001
From: Luna
Date: Fri, 9 Jun 2023 23:19:15 -0300
Subject: [PATCH] add download_images function

---
 .gitignore          |   3 +
 README.md           |  18 ++++-
 config.example.json |   5 ++
 main.py             | 179 ++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt    |   3 +
 5 files changed, 207 insertions(+), 1 deletion(-)
 create mode 100644 config.example.json
 create mode 100644 main.py
 create mode 100644 requirements.txt

diff --git a/.gitignore b/.gitignore
index 5d381cc..f7ab5d7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,6 @@ cython_debug/
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+config.json
+data.db
+posts/
diff --git a/README.md b/README.md
index c0ab80d..fe130f0 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,19 @@
 # tagger-showdown
 
-attempting to create a more readable evaluation of anime tagger ai systems
\ No newline at end of file
+attempting to create a more readable evaluation of anime tagger ai systems
+
+idea: take some recent images from danbooru, also include your own
+
+then run x tagger systems against each other
+
+score formula:
+
+(len(tags in ground_truth) - len(tags not in ground_truth)) / len(ground_truth)
+
+then average for all posts
+
+```sh
+python3 -m venv env
+env/bin/pip install -Ur ./requirements.txt
+env/bin/python3 ./main.py
+```
\ No newline at end of file
diff --git a/config.example.json b/config.example.json
new file mode 100644
index 0000000..add9e23
--- /dev/null
+++ b/config.example.json
@@ -0,0 +1,5 @@
+{
+    "sd_webui_address": "http://100.101.194.71:7860/",
+    "dd_address": "http://localhost:4443",
+    "dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48"
+}
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..624143e
--- /dev/null
+++ b/main.py
@@ -0,0 +1,179 @@
+import json
+import aiohttp
+import sys
+import asyncio
+import aiosqlite
+import logging
+from collections import defaultdict
+from pathlib import Path
+from urllib.parse import urlparse
+from typing import Any, Optional, List
+from aiolimiter import AsyncLimiter
+from dataclasses import dataclass, field
+
+log = logging.getLogger(__name__)
+
+
+DEFAULTS = [
+    "wd14-vit-v2-git",
+    "wd14-vit",
+    "wd14-swinv2-v2-git",
+    "wd14-convnextv2-v2-git",
+    "wd14-convnext-v2-git",
+    "wd14-convnext",
+]
+
+
+@dataclass
+class Config:
+    sd_webui_address: str
+    dd_address: str
+    dd_model_name: str
+    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
+
+
+@dataclass
+class Context:
+    db: Any
+    session: aiohttp.ClientSession
+    config: Config
+
+
+@dataclass
+class Booru:
+    session: aiohttp.ClientSession
+    limiter: AsyncLimiter
+    tag_limiter: AsyncLimiter
+    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
+    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
+
+    fetch_type = "hash"
+
+    @property
+    def hash_style(self):
+        return HashStyle.md5
+
+
+class Danbooru(Booru):
+    title = "Danbooru"
+    base_url = "https://danbooru.donmai.us"
+
+    async def posts(self, tag_query: str, limit):
+        log.info("%s: submit %r", self.title, tag_query)
+        async with self.limiter:
+            log.info("%s: submit upstream %r", self.title, tag_query)
+            async with self.session.get(
+                f"{self.base_url}/posts.json",
+                params={"tags": tag_query, "limit": limit},
+            ) as resp:
+                assert resp.status == 200
+                rjson = await resp.json()
+                return rjson
+
+
+DOWNLOADS = Path.cwd() / "posts"
+DOWNLOADS.mkdir(exist_ok=True)
+
+
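+# download_images mode: fetch recent danbooru posts for the tag query in argv[2]
+# (limit in argv[3], default 10), save each file under ./posts/, and record it in sqlite.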
+async def download_images(ctx):
+    try:
+        tagquery = sys.argv[2]
+    except IndexError:
+        tagquery = ""
+
+    try:
+        limit = int(sys.argv[3])
+    except IndexError:
+        limit = 10
+
+    danbooru = Danbooru(
+        ctx.session,
+        AsyncLimiter(1, 10),
+        AsyncLimiter(1, 3),
+    )
+
+    posts = await danbooru.posts(tagquery, limit)
+    for post in posts:
+        if "md5" not in post:
+            continue
+        log.info("processing post %r", post)
+        existing_post = await fetch_post(ctx, post["md5"])
+        if existing_post:
+            continue
+
+        # download the post
+        post_file_url = post["file_url"]
+        post_filename = Path(urlparse(post_file_url).path).name
+        post_filepath = DOWNLOADS / post_filename
+        if not post_filepath.exists():
+            log.info("downloading %r to %r", post_file_url, post_filepath)
+            async with ctx.session.get(post_file_url) as resp:
+                assert resp.status == 200
+                with post_filepath.open("wb") as fd:
+                    async for chunk in resp.content.iter_chunked(1024):
+                        fd.write(chunk)
+
+        # when it's done successfully, insert it
+        await insert_post(ctx, post)
+
+
+async def fetch_post(ctx, md5) -> Optional[dict]:
+    rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
+    if not rows:
+        return None
+    assert len(rows) == 1
+    return json.loads(rows[0][0])
+
+
+async def insert_post(ctx, post):
+    await ctx.db.execute_insert(
+        "insert into posts (md5, filepath, data) values (?, ?, ?)",
+        (
+            post["md5"],
+            Path(urlparse(post["file_url"]).path).name,
+            json.dumps(post),
+        ),
+    )
+    await ctx.db.commit()
+
+
+async def fight(ctx):
+    pass
+
+
+async def realmain(ctx):
+    await ctx.db.executescript(
+        """
+        create table if not exists posts (md5 text primary key, filepath text, data text);
+        """
+    )
+
+    try:
+        mode = sys.argv[1]
+    except IndexError:
+        raise Exception("must have mode")
+
+    if mode == "download_images":
+        await download_images(ctx)
+    elif mode == "fight":
+        await fight(ctx)
+    else:
+        raise AssertionError(f"invalid mode {mode}")
+
+
+async def main():
+    CONFIG_PATH = Path.cwd() / "config.json"
+    log.info("hewo")
+    async with aiosqlite.connect("./data.db") as db:
+        async with aiohttp.ClientSession() as session:
+            with CONFIG_PATH.open() as config_fd:
+                config_json = json.load(config_fd)
+
+            config = Config(**config_json)
+            ctx = Context(db, session, config)
+            await realmain(ctx)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    asyncio.run(main())
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..db0e4df
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,3 @@
+aiosqlite==0.19.0
+aiohttp==3.8.4
+aiolimiter==1.1.0
\ No newline at end of file
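Note (illustration, not part of the patch): the README's score formula, applied to a single post, could look like the sketch below. `predicted` and `ground_truth` are hypothetical tag sets; the `fight` mode that would actually compute this is still a stub in this commit.

```python
def score_post(predicted: set[str], ground_truth: set[str]) -> float:
    # mirrors the README formula:
    # (len(tags in ground_truth) - len(tags not in ground_truth)) / len(ground_truth)
    hits = len(predicted & ground_truth)    # predicted tags present in the ground truth
    misses = len(predicted - ground_truth)  # predicted tags absent from the ground truth
    return (hits - misses) / len(ground_truth)

# example: ground truth {"1girl", "solo", "smile"} vs prediction {"1girl", "solo", "hat"}
# gives (2 - 1) / 3 = 0.33...; per-post scores are then averaged over all posts.
```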