fix deepdanbooru's rating tags

This commit is contained in:
Luna 2023-06-10 01:30:26 -03:00
parent 18da5d7972
commit 919bd4017e
1 changed files with 26 additions and 2 deletions

28
main.py
View File

@ -8,10 +8,10 @@ import asyncio
import aiosqlite
import base64
import logging
from collections import defaultdict
from collections import defaultdict, Counter
from pathlib import Path
from urllib.parse import urlparse
from typing import Any, Optional, List, Set
from typing import Any, Optional, List, Set, Tuple
from aiolimiter import AsyncLimiter
from dataclasses import dataclass, field
@ -33,8 +33,32 @@ class Interrogator:
model_id: str
address: str
def _process(self, lst):
return lst
async def fetch_tags(self, ctx, md5_hash):
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
return self._process(tag_string.split())
class DDInterrogator(Interrogator):
def _process(self, lst):
new_lst = []
for tag in lst:
if tag.startswith("rating:"):
original_danbooru_tag = tag.split(":")[1]
else:
original_danbooru_tag = tag
if original_danbooru_tag == "safe":
continue
new_lst.append(original_danbooru_tag)
return new_lst
async def interrogate(self, session, path):
async with session.post(
f"{self.address}/",