diff --git a/main.py b/main.py
index 7bbf35e..b87bbc7 100644
--- a/main.py
+++ b/main.py
@@ -37,6 +37,8 @@ DEFAULTS = [
     # broken model: "mld-tresnetd.6-30000",
 ]
 
+RATING = ["general", "explicit", "questionable", "sensitive", "safe"]
+
 
 @dataclass
 class InterrogatorPost:
@@ -48,6 +50,8 @@ class InterrogatorPost:
 class Interrogator:
     model_id: str
     address: str
+    threshold: float = 0.55
+    _fucked_rating: bool = False
 
     def _process(self, lst):
         return lst
@@ -68,11 +72,16 @@ class DDInterrogator(Interrogator):
         new_lst = []
         for tag in lst:
             if tag.startswith("rating:"):
-                original_danbooru_tag = tag.split(":")[1]
+                continue
             else:
                 original_danbooru_tag = tag
+
             if original_danbooru_tag == "safe":
-                original_danbooru_tag = "general"
+                continue
+
+            if original_danbooru_tag in RATING:
+                continue
+
             new_lst.append(original_danbooru_tag)
         return new_lst
 
@@ -80,7 +89,7 @@ class DDInterrogator(Interrogator):
         async with ctx.session.post(
             f"{self.address}/",
             params={
-                "threshold": "0.7",
+                "threshold": "0.55",
             },
             headers={"Authorization": "Bearer sex"},
             data={"file": path.open("rb")},
@@ -92,19 +101,32 @@ class DDInterrogator(Interrogator):
 
 
 class SDInterrogator(Interrogator):
+    def _process(self, lst):
+        new_lst = []
+        for tag in lst:
+            if tag.startswith("rating_"):
+                continue
+            elif tag in RATING:
+                continue
+            else:
+                original_danbooru_tag = tag
+            new_lst.append(original_danbooru_tag)
+        return new_lst
+
     async def interrogate(self, ctx, path):
         async with aiofiles.open(path, "rb") as fd:
             as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
 
+        url = f"{self.address}/tagger/v1/interrogate"
         async with ctx.session.post(
-            f"{self.address}/tagger/v1/interrogate",
+            url,
             json={
                 "model": self.model_id,
-                "threshold": 0.7,
+                "threshold": self.threshold,
                 "image": as_base64,
             },
         ) as resp:
-            log.info("got %d", resp.status)
+            log.info("%s got %d from %s", path, resp.status, url)
             assert resp.status == 200
             data = await resp.json()
             tags = []
@@ -123,11 +145,27 @@ class SDInterrogator(Interrogator):
         return " ".join(upstream_tags)
 
 
+def tag_string_for(post: dict) -> str:
+    return (
+        post["tag_string_general"]
+        + " "
+        + post["tag_string_copyright"]
+        + " "
+        + post["tag_string_character"]
+    )
+
+
 class ControlInterrogator(Interrogator):
+    async def fetch(self, ctx, path):
+        md5_hash = Path(path).stem
+        post = await fetch_post(ctx, md5_hash)
+        tag_string = tag_string_for(post)
+        return InterrogatorPost(tag_string.split(), 0)
+
     async def interrogate(self, ctx, path):
         md5_hash = Path(path).stem
         post = await fetch_post(ctx, md5_hash)
-        return post["tag_string"]
+        return tag_string_for(post)
 
 
 @dataclass
@@ -136,6 +174,8 @@ class Config:
    dd_address: str
    dd_model_name: str
    sd_webui_extras: Dict[str, str]
+    camie_address: str
+    joytag_address: str
    sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
 
     @property
@@ -149,7 +189,11 @@ class Config:
                 SDInterrogator(sd_interrogator, url)
                 for sd_interrogator, url in self.sd_webui_extras.items()
             ]
-            + [DDInterrogator(self.dd_model_name, self.dd_address)]
+            + [
+                DDInterrogator(self.dd_model_name, self.dd_address),
+                SDInterrogator("camie-tagger-v1", self.camie_address, 0.5, True),
+                SDInterrogator("joytag-v1", self.joytag_address, 0.5),
+            ]
             + [ControlInterrogator("control", None)]
         )
 
@@ -183,7 +227,13 @@ class Danbooru(Booru):
     async def posts(self, tag_query: str, limit, page: int):
         log.info("%s: submit %r", self.title, tag_query)
         async with self.limiter:
-            log.info("%s: submit upstream %r", self.title, tag_query)
+            log.info(
+                "%s: submit upstream query=%r limit=%r page=%r",
+                self.title,
+                tag_query,
+                limit,
+                page,
+            )
             async with self.session.get(
                 f"{self.base_url}/posts.json",
                 params={"tags": tag_query, "limit": limit, "page": page},
@@ -267,19 +317,7 @@ async def fetch_post(ctx, md5) -> Optional[dict]:
         return None
     assert len(rows) == 1
     post = json.loads(rows[0][0])
-    post_rating = post["rating"]
-    match post_rating:
-        case "g":
-            rating_tag = "general"
-        case "s":
-            rating_tag = "sensitive"
-        case "q":
-            rating_tag = "questionable"
-        case "e":
-            rating_tag = "explicit"
-        case _:
-            raise AssertionError("invalid post rating {post_rating!r}")
-    post["tag_string"] = post["tag_string"] + " " + rating_tag
+    post["tag_string"] = tag_string_for(post)
     return post
 
 
@@ -305,6 +343,22 @@ async def insert_interrogated_result(
     await ctx.db.commit()
 
 
+async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
+    async with semaphore:
+        log.info("interrogating %r (%d/%d)", missing_hash, index, total)
+        post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
+
+        start_ts = time.monotonic()
+        tag_string = await interrogator.interrogate(ctx, post_filepath)
+        end_ts = time.monotonic()
+        time_taken = round(end_ts - start_ts, 10)
+
+        log.info("took %.5fsec, got %r", time_taken, tag_string)
+        await insert_interrogated_result(
+            ctx, interrogator, missing_hash, tag_string, time_taken
+        )
+
+
 async def fight(ctx):
     interrogators = ctx.config.all_available_models
 
@@ -312,7 +366,7 @@ async def fight(ctx):
     all_hashes = set(r[0] for r in all_rows)
 
     for interrogator in interrogators:
-        log.info("processing for %r", interrogator)
+        log.info("processing fight for %r", interrogator)
         # calculate set of images we didn't interrogate yet
         interrogated_rows = await ctx.db.execute_fetchall(
             "select md5 from interrogated_posts where model_name = ?",
@@ -323,34 +377,63 @@ async def fight(ctx):
 
         log.info("missing %d hashes", len(missing_hashes))
 
+        semaphore = asyncio.Semaphore(3)
+
+        tasks = []
         for index, missing_hash in enumerate(missing_hashes):
-            log.info(
-                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
+            task = process_hash(
+                ctx,
+                interrogator,
+                missing_hash,
+                semaphore,
+                index + 1,
+                len(missing_hashes),
             )
-            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
+            tasks.append(task)
 
-            start_ts = time.monotonic()
-            tag_string = await interrogator.interrogate(ctx, post_filepath)
-            end_ts = time.monotonic()
-            time_taken = round(end_ts - start_ts, 10)
-
-            log.info("took %.5fsec, got %r", time_taken, tag_string)
-            await insert_interrogated_result(
-                ctx, interrogator, missing_hash, tag_string, time_taken
-            )
+        # Run all tasks concurrently with semaphore limiting to 3 at a time
+        await asyncio.gather(*tasks)
 
 
 def score(
     danbooru_tags: Set[str], interrogator_tags: Set[str]
 ) -> Tuple[decimal.Decimal, Set[str]]:
-    tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
+
+    f1 = None
+    # Handle edge cases
+    if len(danbooru_tags) == 0 and len(interrogator_tags) == 0:
+        f1 = decimal.Decimal("1.0")  # Both empty means perfect match
+
+    elif len(danbooru_tags) == 0 or len(interrogator_tags) == 0:
+        f1 = decimal.Decimal("0.0")  # One empty means no match
+
+    # Calculate true positives (tags that appear in both sets)
+    true_positives = decimal.Decimal(len(danbooru_tags.intersection(interrogator_tags)))
+
+    # Calculate precision: TP / (TP + FP)
+    precision = (
+        true_positives / len(interrogator_tags)
+        if len(interrogator_tags) > 0
+        else decimal.Decimal("0.0")
+    )
+
+    # Calculate recall: TP / (TP + FN)
+    recall = (
+        true_positives / len(danbooru_tags)
+        if len(danbooru_tags) > 0
+        else decimal.Decimal("0.0")
+    )
+    log.debug("recall %s", recall)
+
+    # Compute F1 only when the edge cases above did not already decide it
+    if f1 is None and precision == 0 and recall == 0:
+        f1 = decimal.Decimal("0.0")
+    elif f1 is None:
+        f1 = decimal.Decimal("2.0") * (precision * recall) / (precision + recall)
+
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
     return (
-        round(
-            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
-            / decimal.Decimal(len(danbooru_tags)),
-            10,
-        ),
+        round(f1, 10),
         tags_not_in_danbooru,
     )
 
@@ -365,9 +448,10 @@ async def scores(ctx):
     model_scores = defaultdict(dict)
     runtimes = defaultdict(list)
     incorrect_tags_counters = defaultdict(Counter)
+    predicted_tags_counter = defaultdict(int)
 
     for md5_hash in all_hashes:
-        log.info("processing for %r", md5_hash)
+        log.info("processing score for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
@@ -378,9 +462,14 @@
             for tag in incorrect_tags:
                 incorrect_tags_counters[interrogator.model_id][tag] += 1
 
+            log.info(f"{interrogator.model_id} {tagging_score}")
+            predicted_tags_counter[interrogator.model_id] += len(interrogator_tags)
+            correct_tags = interrogator_tags.intersection(danbooru_tags)
             model_scores[interrogator.model_id][md5_hash] = {
                 "score": tagging_score,
+                "predicted_tags": interrogator_tags,
                 "incorrect_tags": incorrect_tags,
+                "correct_tags": correct_tags,
             }
 
     summed_scores = {
@@ -425,8 +514,19 @@ async def scores(ctx):
             print(data["score"], end=",")
         print("]")
 
-        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
-
+        total_incorrect = 0
+        for _, c in incorrect_tags_counters[model].most_common(10000000):
+            total_incorrect += c
+        print(
+            "most incorrect tags from",
+            total_incorrect,
+            "incorrect tags",
+            "predicted",
+            predicted_tags_counter[model],
+            "tags",
+        )
+        for t, c in incorrect_tags_counters[model].most_common(7):
+            print("\t", t, c)
 
     PLOTS = Path.cwd() / "plots"
     PLOTS.mkdir(exist_ok=True)
@@ -459,7 +559,12 @@ async def scores(ctx):
     log.info("plotting positive histogram...")
     plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
     log.info("plotting error rates...")
-    plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
+    plot3(
+        PLOTS / "error_rate.png",
+        PLOTS / "score_avg.png",
+        normalized_scores,
+        model_scores,
+    )
 
 
 def plot2(output_path, normalized_scores, model_scores):
@@ -490,12 +595,13 @@ def plot2(output_path, normalized_scores, model_scores):
     pio.write_image(fig, output_path, width=1024, height=800)
 
 
-def plot3(output_path, normalized_scores, model_scores):
+def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
     data_for_df = {
         "model": [],
-        "errors": [],
-        "rating_errors": [],
-        "practical_errors": [],
+        "score_avg": [],
+        "predicted": [],
+        "correct": [],
+        "incorrect": [],
     }
 
     for model in sorted(
@@ -503,32 +609,34 @@
         key=lambda model: normalized_scores[model],
         reverse=True,
     ):
-        total_incorrect_tags = 0
-        total_rating_errors = 0
+        total_predicted_tags, total_incorrect_tags, total_correct_tags = 0, 0, 0
         for score_data in model_scores[model].values():
+            total_predicted_tags += len(score_data["predicted_tags"])
             total_incorrect_tags += len(score_data["incorrect_tags"])
-            total_rating_errors += sum(
-                1
-                for rating in ["general", "sensitive", "questionable", "explicit"]
-                if rating in score_data["incorrect_tags"]
-            )
-        practical_absolute_error = total_incorrect_tags - total_rating_errors
+            total_correct_tags += len(score_data["correct_tags"])
 
-        data_for_df["errors"].append(total_incorrect_tags)
-        data_for_df["rating_errors"].append(total_rating_errors)
-        data_for_df["practical_errors"].append(practical_absolute_error)
+        data_for_df["score_avg"].append(normalized_scores[model])
+        data_for_df["predicted"].append(total_predicted_tags)
+        data_for_df["incorrect"].append(total_incorrect_tags)
+        data_for_df["correct"].append(total_correct_tags)
         data_for_df["model"].append(model)
 
     df = pd.DataFrame(data_for_df)
 
     fig = go.Figure(
         data=[
-            go.Bar(name="incorrect tags", x=df.model, y=df.errors),
-            go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
-            go.Bar(name="practical error", x=df.model, y=df.practical_errors),
+            go.Bar(name="predicted tags", x=df.model, y=df.predicted),
+            go.Bar(name="incorrect tags", x=df.model, y=df.incorrect),
+            go.Bar(name="correct tags", x=df.model, y=df.correct),
         ]
     )
     pio.write_image(fig, output_path, width=1024, height=800)
+    fig2 = go.Figure(
+        data=[
+            go.Bar(name="score avg", x=df.model, y=df.score_avg),
+        ]
+    )
+    pio.write_image(fig2, output_score_avg_path, width=1024, height=800)
 
 
 async def realmain(ctx):
diff --git a/requirements.txt b/requirements.txt
index cb2f5b8..dba5229 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-aiosqlite==0.20.0
-aiohttp==3.10.0
-aiolimiter>1.1.0<2.0
-aiofiles==24.1.0
-plotly>5.15.0<6.0
-pandas==2.2.2
-kaleido==0.2.1
+aiosqlite
+aiohttp
+aiolimiter>1.1.0
+aiofiles
+plotly>5.15.0
+pandas
+kaleido