From 919bd4017ee1edcd89ab0f9156f7b3eee53fb5fe Mon Sep 17 00:00:00 2001 From: Luna Date: Sat, 10 Jun 2023 01:30:26 -0300 Subject: [PATCH] fix deepdanbooru's rating tags --- main.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index 40a6b06..f9d415c 100644 --- a/main.py +++ b/main.py @@ -8,10 +8,10 @@ import asyncio import aiosqlite import base64 import logging -from collections import defaultdict +from collections import defaultdict, Counter from pathlib import Path from urllib.parse import urlparse -from typing import Any, Optional, List, Set +from typing import Any, Optional, List, Set, Tuple from aiolimiter import AsyncLimiter from dataclasses import dataclass, field @@ -33,8 +33,32 @@ class Interrogator: model_id: str address: str + def _process(self, lst): + return lst + + async def fetch_tags(self, ctx, md5_hash): + rows = await ctx.db.execute_fetchall( + "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?", + (md5_hash, self.model_id), + ) + assert len(rows) == 1 # run 'fight' mode, there's missing posts + tag_string = rows[0][0] + return self._process(tag_string.split()) + class DDInterrogator(Interrogator): + def _process(self, lst): + new_lst = [] + for tag in lst: + if tag.startswith("rating:"): + original_danbooru_tag = tag.split(":")[1] + else: + original_danbooru_tag = tag + if original_danbooru_tag == "safe": + continue + new_lst.append(original_danbooru_tag) + return new_lst + async def interrogate(self, session, path): async with session.post( f"{self.address}/",