import asyncio
import enum
import json
import logging
import sys
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Optional
from urllib.parse import urlparse

import aiohttp
import aiosqlite
from aiolimiter import AsyncLimiter

log = logging.getLogger(__name__)

# Default SD-WebUI interrogator models to query.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]


class HashStyle(enum.Enum):
    # Minimal definition so Booru.hash_style resolves; the original code
    # references HashStyle but its declaration (possibly with more styles)
    # was not present in this file.
    md5 = "md5"


@dataclass
class Config:
    sd_webui_address: str
    dd_address: str
    dd_model_name: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))


@dataclass
class Context:
    db: Any
    session: aiohttp.ClientSession
    config: Config


@dataclass
class Booru:
    session: aiohttp.ClientSession
    limiter: AsyncLimiter
    tag_limiter: AsyncLimiter
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))

    fetch_type = "hash"

    @property
    def hash_style(self):
        return HashStyle.md5


class Danbooru(Booru):
    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                assert resp.status == 200
                return await resp.json()


DOWNLOADS = Path.cwd() / "posts"
DOWNLOADS.mkdir(exist_ok=True)


async def download_images(ctx):
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""

    try:
        limit = int(sys.argv[3])
    except IndexError:
        limit = 10

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )
    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        # some posts come back without an md5 (e.g. restricted posts); skip them
        if "md5" not in post:
            continue

        log.info("processing post %r", post)
        existing_post = await fetch_post(ctx, post["md5"])
        if existing_post:
            continue

        # download the post
        post_file_url = post["file_url"]
        post_filename = Path(urlparse(post_file_url).path).name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            async with ctx.session.get(post_file_url) as resp:
                assert resp.status == 200
                with post_filepath.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)

        # when it's done successfully, insert it
        await insert_post(ctx, post)


async def fetch_post(ctx, md5) -> Optional[dict]:
    rows = await ctx.db.execute_fetchall(
        "select data from posts where md5 = ?", (md5,)
    )
    if not rows:
        return None
    assert len(rows) == 1
    return json.loads(rows[0][0])


async def insert_post(ctx, post):
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath, data) values (?, ?, ?)",
        (
            post["md5"],
            Path(urlparse(post["file_url"]).path).name,
            json.dumps(post),
        ),
    )
    await ctx.db.commit()


async def fight(ctx):
    pass


async def realmain(ctx):
    await ctx.db.executescript(
        """
        create table if not exists posts (md5 text primary key, filepath text, data text);
        """
    )

    try:
        mode = sys.argv[1]
    except IndexError:
        raise Exception("must provide a mode (download_images or fight)")

    if mode == "download_images":
        await download_images(ctx)
    elif mode == "fight":
        await fight(ctx)
    else:
        raise AssertionError(f"invalid mode {mode}")


async def main():
    CONFIG_PATH = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db:
        async with aiohttp.ClientSession() as session:
            with CONFIG_PATH.open() as config_fd:
                config_json = json.load(config_fd)
            config = Config(**config_json)
            ctx = Context(db, session, config)
            await realmain(ctx)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
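
# Usage sketch: the script takes the mode as argv[1], an optional tag query as
# argv[2], and an optional post limit as argv[3], and reads ./config.json with
# the fields declared on Config. The addresses, model name, and script filename
# below are placeholder assumptions, not values taken from this file.
#
#   $ cat config.json
#   {
#       "sd_webui_address": "http://localhost:7860",
#       "dd_address": "http://localhost:4443",
#       "dd_model_name": "example-deepdanbooru-model"
#   }
#
#   $ python this_script.py download_images "1girl" 20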