add download_images function

Luna 2023-06-09 23:19:15 -03:00
parent a26087e079
commit 6d113e77b8
5 changed files with 207 additions and 1 deletion

3
.gitignore vendored

@@ -160,3 +160,6 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
config.json
data.db
posts/

README.md

@@ -1,3 +1,19 @@
# tagger-showdown
attempting to create a more readable evaluation of anime tagger ai systems
idea: take some recent images from danbooru, also include your own
then run x tagger systems against each other
score formula, per post (where `tags` are the tags a given tagger predicted; see the sketch below):
(len(tags in ground_truth) - len(tags not in ground_truth)) / len(ground_truth)
then average over all posts
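A minimal sketch of that scoring in Python (the function and variable names are illustrative, not from this repo):

```python
def score_post(predicted: set, ground_truth: set) -> float:
    # tags the tagger got right, minus tags it predicted that are not
    # in the ground truth, normalised by the ground-truth tag count
    hits = len(predicted & ground_truth)
    misses = len(predicted - ground_truth)
    return (hits - misses) / len(ground_truth)


def average_score(per_post_pairs) -> float:
    # per_post_pairs: list of (predicted_tags, ground_truth_tags) set pairs
    return sum(score_post(p, gt) for p, gt in per_post_pairs) / len(per_post_pairs)
```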
```sh
python3 -m venv env
env/bin/pip install -Ur ./requirements.txt
env/bin/python3 ./main.py
```

5
config.example.json Normal file

@@ -0,0 +1,5 @@
{
    "sd_webui_address": "http://100.101.194.71:7860/",
    "dd_address": "http://localhost:4443",
    "dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48"
}
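Per the `Config` dataclass in main.py below, `sd_webui_models` is an optional extra key that falls back to the wd14 model list in `DEFAULTS`; a config that overrides it might look like this (values are only examples):

```json
{
    "sd_webui_address": "http://localhost:7860/",
    "dd_address": "http://localhost:4443",
    "dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48",
    "sd_webui_models": ["wd14-vit-v2-git", "wd14-convnext"]
}
```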

179
main.py Normal file

@@ -0,0 +1,179 @@
import json
import aiohttp
import sys
import asyncio
import aiosqlite
import logging
from collections import defaultdict
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field

log = logging.getLogger(__name__)

DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]
@dataclass
class Config:
    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))


@dataclass
class Context:
    db: Any
    session: aiohttp.ClientSession
    config: Config


@dataclass
class Booru:
    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    fetch_type = "hash"

    @property
    def hash_style(self):
        # NOTE: HashStyle is not defined or imported in this commit, so
        # accessing this property would raise a NameError
        return HashStyle.md5
class Danbooru(Booru):
    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                assert resp.status == 200
                rjson = await resp.json()
                return rjson


DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)
async def download_images(ctx):
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""
    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        if "md5" not in post:
            continue
        log.info("processing post %r", post)
        existing_post = await fetch_post(ctx, post["md5"])
        if existing_post:
            continue

        # download the post
        post_file_url = post["file_url"]
        post_filename = Path(urlparse(post_file_url).path).name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            async with ctx.session.get(post_file_url) as resp:
                assert resp.status == 200
                with post_filepath.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)

        # when it's done successfully, insert it
        await insert_post(ctx, post)
async def fetch_post(ctx, md5) -> Optional[dict]:
    rows = await ctx.db.execute_fetchall("select data from posts where md5 = ?", (md5,))
    if not rows:
        return None
    assert len(rows) == 1
    return json.loads(rows[0][0])


async def insert_post(ctx, post):
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath, data) values (?, ?, ?)",
        (
            post["md5"],
            Path(urlparse(post["file_url"]).path).name,
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def fight(ctx):
    pass
async def realmain(ctx):
    await ctx.db.executescript(
        """
        create table if not exists posts (md5 text primary key, filepath text, data text);
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode")

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    CONFIG_PATH = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with CONFIG_PATH.open() as config_fd:
                config_json = json.load(config_fd)
                config = Config(**config_json)
                ctx = Context(db, session, config)
                await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
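Going by the argument parsing above, a download run would be invoked roughly like this (the tag query and limit are illustrative; they default to an empty query and 10 posts):

```sh
# argv[1] = mode, argv[2] = danbooru tag query, argv[3] = post limit
env/bin/python3 ./main.py download_images "hatsune_miku" 20
```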

3
requirements.txt Normal file

@@ -0,0 +1,3 @@
aiosqlite==0.19.0
aiohttp==3.8.4
aiolimiter==1.1.0