diff --git a/main.py b/main.py
index b87bbc7..7bbf35e 100644
--- a/main.py
+++ b/main.py
@@ -37,8 +37,6 @@ DEFAULTS = [
     # broken model: "mld-tresnetd.6-30000",
 ]
 
-RATING = ["general", "explicit", "questionable", "sensitive", "safe"]
-
 
 @dataclass
 class InterrogatorPost:
@@ -50,8 +48,6 @@ class InterrogatorPost:
 class Interrogator:
     model_id: str
     address: str
-    threshold: float = 0.55
-    _fucked_rating: bool = False
 
     def _process(self, lst):
         return lst
@@ -72,16 +68,11 @@ class DDInterrogator(Interrogator):
         new_lst = []
         for tag in lst:
             if tag.startswith("rating:"):
-                continue
+                original_danbooru_tag = tag.split(":")[1]
             else:
                 original_danbooru_tag = tag
 
-            if original_danbooru_tag == "safe":
-                continue
-
-            if original_danbooru_tag in RATING:
-                continue
-
+            original_danbooru_tag = "general"
             new_lst.append(original_danbooru_tag)
         return new_lst
 
@@ -89,7 +80,7 @@ class DDInterrogator(Interrogator):
         async with ctx.session.post(
             f"{self.address}/",
             params={
-                "threshold": "0.55",
+                "threshold": "0.7",
             },
             headers={"Authorization": "Bearer sex"},
             data={"file": path.open("rb")},
@@ -101,32 +92,19 @@ class SDInterrogator(Interrogator):
-    def _process(self, lst):
-        new_lst = []
-        for tag in lst:
-            if tag.startswith("rating_"):
-                continue
-            elif tag in RATING:
-                continue
-            else:
-                original_danbooru_tag = tag
-            new_lst.append(original_danbooru_tag)
-        return new_lst
-
     async def interrogate(self, ctx, path):
         async with aiofiles.open(path, "rb") as fd:
             as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
-        url = f"{self.address}/tagger/v1/interrogate"
         async with ctx.session.post(
-            url,
+            f"{self.address}/tagger/v1/interrogate",
            json={
                 "model": self.model_id,
-                "threshold": self.threshold,
+                "threshold": 0.7,
                 "image": as_base64,
             },
         ) as resp:
-            log.info("%s got %d from %s", path, resp.status, url)
+            log.info("got %d", resp.status)
             assert resp.status == 200
             data = await resp.json()
             tags = []
@@ -145,27 +123,11 @@ class SDInterrogator(Interrogator):
         return " ".join(upstream_tags)
 
 
-def tag_string_for(post: dict) -> str:
-    return (
-        post["tag_string_general"]
-        + " "
-        + post["tag_string_copyright"]
-        + " "
-        + post["tag_string_character"]
-    )
-
-
 class ControlInterrogator(Interrogator):
-    async def fetch(self, ctx, path):
-        md5_hash = Path(path).stem
-        post = await fetch_post(ctx, md5_hash)
-        tag_string = tag_string_for(post)
-        return InterrogatorPost(tag_string.split(), 0)
-
     async def interrogate(self, ctx, path):
         md5_hash = Path(path).stem
         post = await fetch_post(ctx, md5_hash)
-        return tag_string_for(post)
+        return post["tag_string"]
 
 
 @dataclass
@@ -174,8 +136,6 @@ class Config:
     dd_address: str
     dd_model_name: str
     sd_webui_extras: Dict[str, str]
-    camie_address: str
-    joytag_address: str
     sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
 
     @property
@@ -189,11 +149,7 @@ class Config:
                 SDInterrogator(sd_interrogator, url)
                 for sd_interrogator, url in self.sd_webui_extras.items()
             ]
-            + [
-                DDInterrogator(self.dd_model_name, self.dd_address),
-                SDInterrogator("camie-tagger-v1", self.camie_address, 0.5, True),
-                SDInterrogator("joytag-v1", self.joytag_address, 0.5),
-            ]
+            + [DDInterrogator(self.dd_model_name, self.dd_address)]
             + [ControlInterrogator("control", None)]
         )
 
@@ -227,13 +183,7 @@ class Danbooru(Booru):
     async def posts(self, tag_query: str, limit, page: int):
         log.info("%s: submit %r", self.title, tag_query)
         async with self.limiter:
-            log.info(
-                "%s: submit upstream query=%r limit=%r page=%r",
-                self.title,
-                tag_query,
-                limit,
-                page,
-            )
+            log.info("%s: submit upstream %r", self.title, tag_query)
             async with self.session.get(
                 f"{self.base_url}/posts.json",
                 params={"tags": tag_query, "limit": limit, "page": page},
@@ -317,7 +267,19 @@ async def fetch_post(ctx, md5) -> Optional[dict]:
         return None
     assert len(rows) == 1
     post = json.loads(rows[0][0])
-    post["tag_string"] = tag_string_for(post)
+    post_rating = post["rating"]
+    match post_rating:
+        case "g":
+            rating_tag = "general"
+        case "s":
+            rating_tag = "sensitive"
+        case "q":
+            rating_tag = "questionable"
+        case "e":
+            rating_tag = "explicit"
+        case _:
+            raise AssertionError(f"invalid post rating {post_rating!r}")
+    post["tag_string"] = post["tag_string"] + " " + rating_tag
     return post
 
 
@@ -343,22 +305,6 @@ async def insert_interrogated_result(
     await ctx.db.commit()
 
 
-async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
-    async with semaphore:
-        log.info("interrogating %r (%d/%d)", missing_hash, index, total)
-        post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
-
-        start_ts = time.monotonic()
-        tag_string = await interrogator.interrogate(ctx, post_filepath)
-        end_ts = time.monotonic()
-        time_taken = round(end_ts - start_ts, 10)
-
-        log.info("took %.5fsec, got %r", time_taken, tag_string)
-        await insert_interrogated_result(
-            ctx, interrogator, missing_hash, tag_string, time_taken
-        )
-
-
 async def fight(ctx):
     interrogators = ctx.config.all_available_models
 
@@ -366,7 +312,7 @@ async def fight(ctx):
     all_hashes = set(r[0] for r in all_rows)
 
     for interrogator in interrogators:
-        log.info("processing fight for %r", interrogator)
+        log.info("processing for %r", interrogator)
         # calculate set of images we didn't interrogate yet
         interrogated_rows = await ctx.db.execute_fetchall(
             "select md5 from interrogated_posts where model_name = ?",
@@ -377,63 +323,34 @@ async def fight(ctx):
         log.info("missing %d hashes", len(missing_hashes))
 
-        semaphore = asyncio.Semaphore(3)
-
-        tasks = []
         for index, missing_hash in enumerate(missing_hashes):
-            task = process_hash(
-                ctx,
-                interrogator,
-                missing_hash,
-                semaphore,
-                index + 1,
-                len(missing_hashes),
+            log.info(
+                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
             )
-            tasks.append(task)
+            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
 
-        # Run all tasks concurrently with semaphore limiting to 3 at a time
-        await asyncio.gather(*tasks)
+            start_ts = time.monotonic()
+            tag_string = await interrogator.interrogate(ctx, post_filepath)
+            end_ts = time.monotonic()
+            time_taken = round(end_ts - start_ts, 10)
+
+            log.info("took %.5fsec, got %r", time_taken, tag_string)
+            await insert_interrogated_result(
+                ctx, interrogator, missing_hash, tag_string, time_taken
+            )
 
 
 def score(
     danbooru_tags: Set[str], interrogator_tags: Set[str]
 ) -> Tuple[decimal.Decimal, Set[str]]:
-
-    f1 = None
-    # Handle edge cases
-    if len(danbooru_tags) == 0 and len(interrogator_tags) == 0:
-        f1 = decimal.Decimal("1.0")  # Both empty means perfect match
-
-    if len(danbooru_tags) == 0 or len(interrogator_tags) == 0:
-        f1 = decimal.Decimal("0.0")  # One empty means no match
-
-    # Calculate true positives (tags that appear in both sets)
-    true_positives = decimal.Decimal(len(danbooru_tags.intersection(interrogator_tags)))
-
-    # Calculate precision: TP / (TP + FP)
-    precision = (
-        true_positives / len(interrogator_tags)
-        if len(interrogator_tags) > 0
-        else decimal.Decimal("0.0")
-    )
-
-    # Calculate recall: TP / (TP + FN)
-    recall = (
-        true_positives / len(danbooru_tags)
-        if len(danbooru_tags) > 0
-        else decimal.Decimal("0.0")
-    )
-    print("recall", recall)
-
-    # Handle the case where both precision and recall are 0
-    if f1 is None and precision == 0 and recall == 0:
-        f1 = decimal.Decimal("0.0")
-    else:
-        f1 = decimal.Decimal("2.0") * (precision * recall) / (precision + recall)
-
+    tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
     return (
-        round(f1, 10),
+        round(
+            decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
+            / decimal.Decimal(len(danbooru_tags)),
+            10,
+        ),
         tags_not_in_danbooru,
     )
@@ -448,10 +365,9 @@ async def scores(ctx):
     model_scores = defaultdict(dict)
     runtimes = defaultdict(list)
     incorrect_tags_counters = defaultdict(Counter)
-    predicted_tags_counter = defaultdict(int)
 
     for md5_hash in all_hashes:
-        log.info("processing score for %r", md5_hash)
+        log.info("processing for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
@@ -462,14 +378,9 @@ async def scores(ctx):
             for tag in incorrect_tags:
                 incorrect_tags_counters[interrogator.model_id][tag] += 1
 
-            log.info(f"{interrogator.model_id} {tagging_score}")
-            predicted_tags_counter[interrogator.model_id] += len(interrogator_tags)
-            correct_tags = interrogator_tags.intersection(danbooru_tags)
             model_scores[interrogator.model_id][md5_hash] = {
                 "score": tagging_score,
-                "predicted_tags": interrogator_tags,
                 "incorrect_tags": incorrect_tags,
-                "correct_tags": correct_tags,
             }
 
     summed_scores = {
@@ -514,19 +425,8 @@ async def scores(ctx):
             print(data["score"], end=",")
         print("]")
 
-        total_incorrect = 0
-        for _, c in incorrect_tags_counters[model].most_common(10000000):
-            total_incorrect += c
-        print(
-            "most incorrect tags from",
-            total_incorrect,
-            "incorrect tags",
-            "predicted",
-            predicted_tags_counter[model],
-            "tags",
-        )
-        for t, c in incorrect_tags_counters[model].most_common(7):
-            print("\t", t, c)
+        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
+
 
     PLOTS = Path.cwd() / "plots"
     PLOTS.mkdir(exist_ok=True)
@@ -559,12 +459,7 @@ async def scores(ctx):
     log.info("plotting positive histogram...")
     plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
     log.info("plotting error rates...")
-    plot3(
-        PLOTS / "error_rate.png",
-        PLOTS / "score_avg.png",
-        normalized_scores,
-        model_scores,
-    )
+    plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
 
 
 def plot2(output_path, normalized_scores, model_scores):
@@ -595,13 +490,12 @@ def plot2(output_path, normalized_scores, model_scores):
     pio.write_image(fig, output_path, width=1024, height=800)
 
 
-def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
+def plot3(output_path, normalized_scores, model_scores):
     data_for_df = {
         "model": [],
-        "score_avg": [],
-        "predicted": [],
-        "correct": [],
-        "incorrect": [],
+        "errors": [],
+        "rating_errors": [],
+        "practical_errors": [],
     }
 
     for model in sorted(
@@ -609,34 +503,32 @@ def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
         key=lambda model: normalized_scores[model],
         reverse=True,
     ):
-        total_predicted_tags, total_incorrect_tags, total_correct_tags = 0, 0, 0
+        total_incorrect_tags = 0
+        total_rating_errors = 0
         for score_data in model_scores[model].values():
-            total_predicted_tags += len(score_data["predicted_tags"])
             total_incorrect_tags += len(score_data["incorrect_tags"])
-            total_correct_tags += len(score_data["correct_tags"])
+            total_rating_errors += sum(
+                1
+                for rating in ["general", "sensitive", "questionable", "explicit"]
+                if rating in score_data["incorrect_tags"]
+            )
+        practical_absolute_error = total_incorrect_tags - total_rating_errors
 
-        data_for_df["score_avg"].append(normalized_scores[model])
-        data_for_df["predicted"].append(total_predicted_tags)
-        data_for_df["incorrect"].append(total_incorrect_tags)
-        data_for_df["correct"].append(total_correct_tags)
+        data_for_df["errors"].append(total_incorrect_tags)
+        data_for_df["rating_errors"].append(total_rating_errors)
+        data_for_df["practical_errors"].append(practical_absolute_error)
         data_for_df["model"].append(model)
 
     df = pd.DataFrame(data_for_df)
     fig = go.Figure(
         data=[
-            go.Bar(name="predicted tags", x=df.model, y=df.predicted),
-            go.Bar(name="incorrect tags", x=df.model, y=df.incorrect),
-            go.Bar(name="correct tags", x=df.model, y=df.correct),
+            go.Bar(name="incorrect tags", x=df.model, y=df.errors),
+            go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
+            go.Bar(name="practical error", x=df.model, y=df.practical_errors),
         ]
     )
     pio.write_image(fig, output_path, width=1024, height=800)
 
-    fig2 = go.Figure(
-        data=[
-            go.Bar(name="score avg", x=df.model, y=df.score_avg),
-        ]
-    )
-    pio.write_image(fig2, output_score_avg_path, width=1024, height=800)
 
 
 async def realmain(ctx):
diff --git a/requirements.txt b/requirements.txt
index dba5229..cb2f5b8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
-aiosqlite
-aiohttp
-aiolimiter>1.1.0
-aiofiles
-plotly>5.15.0
-pandas
-kaleido
+aiosqlite==0.20.0
+aiohttp==3.10.0
+aiolimiter>1.1.0,<2.0
+aiofiles==24.1.0
+plotly>5.15.0,<6.0
+pandas==2.2.2
+kaleido==0.2.1