From f53ae5779a4e5aa3fb36f834cd3561d116453429 Mon Sep 17 00:00:00 2001
From: Luna
Date: Sat, 10 Jun 2023 03:22:20 -0300
Subject: [PATCH] expose average runtime to report

---
 main.py | 22 ++++++++++++++++------
 1 file changed, 16 insertions(+), 6 deletions(-)

diff --git a/main.py b/main.py
index 1d817ed..7dfee96 100644
--- a/main.py
+++ b/main.py
@@ -30,6 +30,12 @@ DEFAULTS = [
 ]
 
 
+@dataclass
+class InterrogatorPost:
+    tags: List[str]
+    time_taken: float
+
+
 @dataclass
 class Interrogator:
     model_id: str
@@ -38,14 +44,14 @@ class Interrogator:
     def _process(self, lst):
         return lst
 
-    async def fetch_tags(self, ctx, md5_hash):
+    async def fetch(self, ctx, md5_hash):
         rows = await ctx.db.execute_fetchall(
-            "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
+            "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
             (md5_hash, self.model_id),
         )
         assert len(rows) == 1  # run 'fight' mode, there's missing posts
-        tag_string = rows[0][0]
-        return self._process(tag_string.split())
+        tag_string, time_taken = rows[0][0], rows[0][1]
+        return InterrogatorPost(self._process(tag_string.split()), time_taken)
 
 
 class DDInterrogator(Interrogator):
@@ -318,6 +324,7 @@ async def scores(ctx):
 
     # absolute_scores = defaultdict(decimal.Decimal)
     model_scores = defaultdict(dict)
+    runtimes = defaultdict(list)
    incorrect_tags_counters = defaultdict(Counter)
 
     for md5_hash in all_hashes:
@@ -325,7 +332,9 @@
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
-            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
+            post_data = await interrogator.fetch(ctx, md5_hash)
+            runtimes[interrogator.model_id].append(post_data.time_taken)
+            interrogator_tags = set(post_data.tags)
             tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
             for tag in incorrect_tags:
                 incorrect_tags_counters[interrogator.model_id][tag] += 1
@@ -352,7 +361,8 @@
         key=lambda model: normalized_scores[model],
         reverse=True,
     ):
-        print(model, normalized_scores[model])
+        average_runtime = sum(runtimes[model]) / len(runtimes[model])
+        print(model, normalized_scores[model], "runtime", average_runtime, "sec")
 
     if os.getenv("SHOWOFF", "0") == "1":
         print("[", end="")
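
Not part of the patch itself: a minimal, self-contained sketch of how the new InterrogatorPost value feeds the average-runtime figure in the report. The model name "dd-v3", the timings, and the stand-in loop below are invented for illustration; only the dataclass and the averaging mirror the patch.

    # Illustrative sketch only: mimics how scores() aggregates per-model runtimes
    # from InterrogatorPost values. "dd-v3" and the timings are made-up examples.
    from collections import defaultdict
    from dataclasses import dataclass
    from typing import List


    @dataclass
    class InterrogatorPost:
        tags: List[str]
        time_taken: float


    runtimes = defaultdict(list)

    # Pretend fetch() returned these two posts for a single model.
    for post_data in [
        InterrogatorPost(tags=["1girl", "solo"], time_taken=1.8),
        InterrogatorPost(tags=["landscape"], time_taken=2.2),
    ]:
        runtimes["dd-v3"].append(post_data.time_taken)

    # Same average computation the new report line uses.
    for model, times in runtimes.items():
        average_runtime = sum(times) / len(times)
        print(model, "runtime", average_runtime, "sec")  # -> dd-v3 runtime 2.0 sec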