179 lines
4.5 KiB
Python
179 lines
4.5 KiB
Python
import json
|
|
import aiohttp
|
|
import sys
|
|
import asyncio
|
|
import aiosqlite
|
|
import logging
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
from urllib.parse import urlparse
|
|
from typing import Any, Optional, List
|
|
from aiolimiter import AsyncLimiter
|
|
from dataclasses import dataclass, field
|
|
|
|
# Module-level logger, named after this module per stdlib logging convention.
log = logging.getLogger(__name__)
|
|
|
|
|
|
# Default tagger model names (WD14 family) used as the fallback value for
# Config.sd_webui_models when config.json does not override it.
# NOTE(review): presumably these are interrogator names exposed by the
# SD-WebUI tagger extension — confirm against that API.
DEFAULTS = [
    "wd14-vit-v2-git",
    "wd14-vit",
    "wd14-swinv2-v2-git",
    "wd14-convnextv2-v2-git",
    "wd14-convnext-v2-git",
    "wd14-convnext",
]
|
|
|
|
|
|
@dataclass
class Config:
    """Settings loaded from config.json (constructed only in main()).

    None of the address fields are read anywhere in this file, so their
    exact semantics are unverified here.
    """

    # Presumably the Stable Diffusion WebUI endpoint — TODO confirm with callers.
    sd_webui_address: str
    # Presumably a DeepDanbooru service endpoint — TODO confirm with callers.
    dd_address: str
    dd_model_name: str
    # Tagger model names; defaults to a fresh copy of DEFAULTS so instances
    # never share the module-level list.
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
|
|
|
|
|
|
@dataclass
class Context:
    """Shared resources handed to every mode handler (see realmain())."""

    # aiosqlite connection opened in main(); typed Any rather than the
    # concrete aiosqlite type.
    db: Any
    session: aiohttp.ClientSession
    config: Config
|
|
|
|
|
|
@dataclass
class Booru:
    """Base class for booru API clients (see Danbooru below)."""

    session: aiohttp.ClientSession
    # Rate limiter applied around post queries (used in Danbooru.posts).
    limiter: AsyncLimiter
    # Separate limiter, presumably for tag lookups — unused in visible code.
    tag_limiter: AsyncLimiter
    # Per-key asyncio locks; defaultdict so a lock is created on first access.
    file_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))
    tag_locks: dict = field(default_factory=lambda: defaultdict(asyncio.Lock))

    # Plain class attribute (no annotation), NOT a dataclass field.
    fetch_type = "hash"

    @property
    def hash_style(self):
        # NOTE(review): HashStyle is neither defined nor imported in this
        # file — accessing this property would raise NameError as written.
        # Confirm HashStyle is provided elsewhere in the project.
        return HashStyle.md5
|
|
|
|
|
|
class Danbooru(Booru):
    """Client for the Danbooru JSON posts API."""

    title = "Danbooru"
    base_url = "https://danbooru.donmai.us"

    async def posts(self, tag_query: str, limit):
        """Fetch up to *limit* posts matching *tag_query*.

        The request is throttled through self.limiter.  Returns the decoded
        JSON response (a list of post dicts on success).

        Raises aiohttp.ClientResponseError on a non-2xx HTTP status.
        """
        log.info("%s: submit %r", self.title, tag_query)
        async with self.limiter:
            log.info("%s: submit upstream %r", self.title, tag_query)
            async with self.session.get(
                f"{self.base_url}/posts.json",
                params={"tags": tag_query, "limit": limit},
            ) as resp:
                # raise_for_status() instead of `assert resp.status == 200`:
                # asserts are stripped under `python -O`, which would silently
                # accept error responses and try to JSON-decode them.
                resp.raise_for_status()
                return await resp.json()
|
|
|
|
|
|
# Directory where downloaded post files land; created at import time if absent.
DOWNLOADS = Path.cwd().joinpath("posts")
DOWNLOADS.mkdir(exist_ok=True)
|
|
|
|
|
|
async def download_images(ctx):
    """Download posts matching a tag query from Danbooru into DOWNLOADS.

    CLI contract: sys.argv[2] is the tag query (default ""), sys.argv[3]
    the post limit (default 10).  Posts already recorded in the database
    are skipped; each successfully downloaded post is inserted afterwards.
    """
    try:
        tagquery = sys.argv[2]
    except IndexError:
        tagquery = ""

    try:
        limit = int(sys.argv[3])
    except (IndexError, ValueError):
        # Missing OR non-numeric limit argument falls back to a small default
        # (the original only caught IndexError, so "abc" would crash).
        limit = 10

    danbooru = Danbooru(
        ctx.session,
        AsyncLimiter(1, 10),
        AsyncLimiter(1, 3),
    )

    posts = await danbooru.posts(tagquery, limit)
    for post in posts:
        # Some posts (e.g. restricted ones) come back without an md5; skip.
        if "md5" not in post:
            continue
        log.info("processing post %r", post)
        existing_post = await fetch_post(ctx, post["md5"])
        if existing_post:
            continue

        # download the post
        post_file_url = post["file_url"]
        post_filename = Path(urlparse(post_file_url).path).name
        post_filepath = DOWNLOADS / post_filename
        if not post_filepath.exists():
            log.info("downloading %r to %r", post_file_url, post_filepath)
            # Stream into a temporary ".part" file and rename on completion,
            # so an interrupted download never leaves a partial file that the
            # exists() check above would mistake for a finished one.
            partial_path = post_filepath.with_name(post_filepath.name + ".part")
            async with ctx.session.get(post_file_url) as resp:
                # raise_for_status() instead of assert: asserts vanish under -O.
                resp.raise_for_status()
                with partial_path.open("wb") as fd:
                    async for chunk in resp.content.iter_chunked(1024):
                        fd.write(chunk)
            partial_path.rename(post_filepath)

        # when it's done successfully, insert it
        await insert_post(ctx, post)
|
|
|
|
|
|
async def fetch_post(ctx, md5) -> Optional[dict]:
    """Return the stored post dict for *md5*, or None when not recorded."""
    rows = await ctx.db.execute_fetchall(
        "select data from posts where md5 = ?", (md5,)
    )
    if rows:
        # md5 is the table's primary key, so at most one row can match.
        assert len(rows) == 1
        (raw_json,) = rows[0]
        return json.loads(raw_json)
    return None
|
|
|
|
|
|
async def insert_post(ctx, post):
    """Store *post* in the posts table (md5, filename, raw JSON) and commit."""
    filename = Path(urlparse(post["file_url"]).path).name
    row = (post["md5"], filename, json.dumps(post))
    await ctx.db.execute_insert(
        "insert into posts (md5, filepath ,data) values (?, ?,?)",
        row,
    )
    await ctx.db.commit()
|
|
|
|
|
|
async def fight(ctx):
    """Stub for the "fight" CLI mode; not yet implemented, does nothing."""
|
|
|
|
|
|
async def realmain(ctx):
    """Ensure the schema exists, then dispatch on the mode in sys.argv[1]."""
    await ctx.db.executescript(
        """
        create table if not exists posts (md5 text primary key, filepath text, data text);
        """
    )

    # A missing mode argument is a hard error.
    if len(sys.argv) < 2:
        raise Exception("must have mode")
    mode = sys.argv[1]

    # Dispatch table instead of an if/elif chain.
    handlers = {
        "download_images": download_images,
        "fight": fight,
    }
    handler = handlers.get(mode)
    if handler is None:
        raise AssertionError(f"invalid mode {mode}")
    await handler(ctx)
|
|
|
|
|
|
async def main():
    """Entry point: open the database and HTTP session, load config, run."""
    config_path = Path.cwd() / "config.json"
    log.info("hewo")
    async with aiosqlite.connect("./data.db") as db, aiohttp.ClientSession() as session:
        # Load and validate config.json into the Config dataclass.
        with config_path.open() as config_fd:
            config_json = json.load(config_fd)

        config = Config(**config_json)
        ctx = Context(db, session, config)
        await realmain(ctx)
|
|
|
|
|
|
if __name__ == "__main__":
    # Configure root logging before entering the event loop so all module
    # loggers (including `log` above) emit at INFO.
    logging.basicConfig(level=logging.INFO)
    asyncio.run(main())
|