add download_images function
This commit is contained in:
parent
a26087e079
commit
6d113e77b8
5 changed files with 207 additions and 1 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -160,3 +160,6 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
config.json
|
||||
data.db
|
||||
posts/
|
||||
|
|
18
README.md
18
README.md
|
@ -1,3 +1,19 @@
|
|||
# tagger-showdown
|
||||
|
||||
An attempt at a more readable evaluation of anime tagger AI systems.
|
||||
An attempt at a more readable evaluation of anime tagger AI systems.
|
||||
|
||||
idea: take some recent images from danbooru, also include your own
|
||||
|
||||
then run x tagger systems against each other
|
||||
|
||||
score formula:
|
||||
|
||||
(len(predicted tags in ground_truth) - len(predicted tags not in ground_truth)) / len(ground_truth)
|
||||
|
||||
then average for all posts
|
||||
|
||||
```sh
|
||||
python3 -m venv env
|
||||
env/bin/pip install -Ur ./requirements.txt
|
||||
env/bin/python3 ./main.py
|
||||
```
|
5
config.example.json
Normal file
5
config.example.json
Normal file
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"sd_webui_address": "http://100.101.194.71:7860/",
|
||||
"dd_address": "http://localhost:4443",
|
||||
"dd_model_name": "hydrus-dd_model.h5_0c0b84c2436489eda29ccc9ee4827b48"
|
||||
}
|
179
main.py
Normal file
179
main.py
Normal file
|
@ -0,0 +1,179 @@
|
|||
import json
|
||||
import aiohttp
|
||||
import sys
|
||||
import asyncio
|
||||
import aiosqlite
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
from typing import Any, Optional, List
|
||||
from aiolimiter import AsyncLimiter
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Default entries for `Config.sd_webui_models`, used when the config file
# does not provide that key.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]
|
||||
|
||||
|
||||
@dataclass
class Config:
    """Runtime settings, built from the keys of the JSON config file.

    `sd_webui_models` is optional; when omitted, each instance receives its
    own fresh copy of `DEFAULTS`.
    """

    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    # Copy the module-level list so instances never share (or mutate) it.
    sd_webui_models: List[str] = field(default_factory=lambda: [*DEFAULTS])
|
||||
|
||||
|
||||
@dataclass
class Context:
    """Bundle of shared resources handed to every mode handler."""

    # Database connection; `main()` passes the object yielded by
    # `aiosqlite.connect(...)`.
    db: Any
    # HTTP session reused for all outgoing requests.
    session: aiohttp.ClientSession
    # Parsed configuration.
    config: Config
|
||||
|
||||
|
||||
@dataclass
class Booru:
    """Common state for a booru API client.

    Holds the HTTP session, two rate limiters and two lazily populated lock
    tables. Subclasses supply the site-specific attributes and request
    methods.
    """

    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    # Each missing key gets a new `asyncio.Lock` on first access.
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))

    # Un-annotated, so this is a plain class attribute and not a dataclass field.
    fetch_type = "hash"

    @property
    def hash_style(self):
        # NOTE(review): `HashStyle` is neither defined nor imported in the
        # visible file, so reading this property looks like it would raise
        # NameError -- confirm where `HashStyle` is meant to come from.
        return HashStyle.md5
|
||||
|
||||
|
||||
class Danbooru(Booru):
    """Client for the Danbooru JSON API."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        """Fetch posts matching `tag_query` from `/posts.json`.

        The request is issued while holding `self.limiter`.

        Args:
            tag_query: value sent as the `tags` query parameter.
            limit: value sent as the `limit` query parameter.

        Returns:
            The decoded JSON body of the response.

        Raises:
            AssertionError: if the response status is not 200. This is raised
                explicitly (not via the `assert` statement) so the check is
                still enforced when Python runs with `-O`.
        """
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                if resp.status != 200:
                    raise AssertionError(
                        f"{self.title}: unexpected status {resp.status} "
                        f"for query {tag_query!r}"
                    )
                return await resp.json()
|
||||
|
||||
|
||||
# Directory that downloaded post files are written to: `posts/` under the
# current working directory. It is created at import time if missing.
DOWNLOADS = Path.cwd().joinpath("posts")
DOWNLOADS.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
async def download_images(ctx):
    """Download posts from Danbooru and record them in the database.

    Command-line arguments:
        sys.argv[2]: tag query (default: empty string).
        sys.argv[3]: maximum number of posts to request (default: 10).

    A post is skipped when it has no "md5" key or is already in the `posts`
    table. Otherwise its file is downloaded into `DOWNLOADS` (unless a file
    with that name is already there) and the post is then inserted.

    Each download is streamed to a temporary `<name>.part` file and renamed
    into place only after the whole body has been written. An interrupted
    download therefore never leaves a truncated file under the final name,
    which a later run would mistake for a finished download.

    Args:
        ctx: the shared `Context`; its `session` is used for HTTP and it is
            passed on to `fetch_post` / `insert_post`.

    Raises:
        AssertionError: if a file download answers with a status other than
            200. Raised explicitly so the check survives `python -O`.
        ValueError: if sys.argv[3] is present but is not an integer.
    """
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""

    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10

    # AsyncLimiter(max_rate, time_period): one request per 10 seconds for the
    # main limiter, one per 3 seconds for the tag limiter.
    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )

    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        if "md5" not in post:
            continue
        log.info("processing post %r", post)
        existing_post = await fetch_post(ctx, post["md5"])
        if existing_post:
            continue

        # download the post
        post_file_url = post["file_url"]
        post_filename = Path(urlparse(post_file_url).path).name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            # Opening with "wb" truncates any stale .part file left behind by
            # an earlier failed attempt.
            partial_filepath = post_filepath.with_name(post_filename + ".part")
            async with ctx.session.get(post_file_url) as resp:
                if resp.status != 200:
                    raise AssertionError(
                        f"unexpected status {resp.status} "
                        f"downloading {post_file_url!r}"
                    )
                with partial_filepath.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)
            # Publish the file under its final name only once it is complete.
            partial_filepath.replace(post_filepath)

        # when it's done successfully, insert it
        await insert_post(ctx, post)
|
||||
|
||||
|
||||
async def fetch_post(ctx, md5) -> Optional[dict]:
    """Look up a stored post by md5.

    Returns the decoded JSON from the `data` column of the matching row in
    `posts`, or None when no row matches.
    """
    matches = await ctx.db.execute_fetchall(
        "select data from posts where md5 = ?", (md5,)
    )
    if matches:
        # md5 is the table's primary key, so at most one row can match.
        assert len(matches) == 1
        return json.loads(matches[0][0])
    return None
|
||||
|
||||
|
||||
async def insert_post(ctx, post):
    """Store `post` in the `posts` table and commit.

    The row holds the post's md5, the file name taken from the path part of
    its `file_url`, and the whole post serialized as JSON.
    """
    statement = "insert into posts (md5, filepath ,data) values (?, ?,?)"
    row = (
        post["md5"],
        Path(urlparse(post["file_url"]).path).name,
        json.dumps(post),
    )
    await ctx.db.execute_insert(statement, row)
    await ctx.db.commit()
|
||||
|
||||
|
||||
async def fight(ctx):
    """Handler for the "fight" mode; not implemented yet, so it does nothing."""
    return None
|
||||
|
||||
|
||||
async def realmain(ctx):
    """Ensure the schema exists, then run the mode named by sys.argv[1].

    Supported modes are "download_images" and "fight".

    Raises:
        Exception: if no mode argument was given.
        AssertionError: if the mode is not one of the supported names.
    """
    schema = """
        create table if not exists posts (md5 text primary key, filepath text, data text);
        """
    await ctx.db.executescript(schema)

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must have mode")

    if mode == "download_images":
        await download_images(ctx)
        return
    if mode == "fight":
        await fight(ctx)
        return
    raise AssertionError(f"invalid mode {mode}")
|
||||
|
||||
|
||||
async def main():
    """Open the database and HTTP session, load the config, run `realmain`.

    The database is `./data.db` and the config is `config.json` in the
    current working directory. Both resources are closed when this returns.
    """
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    # Same acquisition order as nested blocks: database first, then session.
    async with aiosqlite.connect("./data.db") as db, aiohttp.ClientSession() as session:
        with config_path.open() as config_fd:
            config = Config(**json.load(config_fd))

        await realmain(Context(db, session, config))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: enable INFO-level logging, then run the async main.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
|
3
requirements.txt
Normal file
3
requirements.txt
Normal file
|
@ -0,0 +1,3 @@
|
|||
aiosqlite==0.19.0
|
||||
aiohttp==3.8.4
|
||||
aiolimiter==1.1.0
|
Loading…
Reference in a new issue