expose averaage runtime to report

This commit is contained in:
Luna 2023-06-10 03:22:20 -03:00
parent 4ab35b285e
commit f53ae5779a
1 changed files with 16 additions and 6 deletions

22
main.py
View File

@ -30,6 +30,12 @@ DEFAULTS = [
] ]
@dataclass
class InterrogatorPost:
tags: List[str]
time_taken: float
@dataclass @dataclass
class Interrogator: class Interrogator:
model_id: str model_id: str
@ -38,14 +44,14 @@ class Interrogator:
def _process(self, lst): def _process(self, lst):
return lst return lst
async def fetch_tags(self, ctx, md5_hash): async def fetch(self, ctx, md5_hash):
rows = await ctx.db.execute_fetchall( rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?", "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id), (md5_hash, self.model_id),
) )
assert len(rows) == 1 # run 'fight' mode, there's missing posts assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0] tag_string, time_taken = rows[0][0], rows[0][1]
return self._process(tag_string.split()) return InterrogatorPost(self._process(tag_string.split()), time_taken)
class DDInterrogator(Interrogator): class DDInterrogator(Interrogator):
@ -318,6 +324,7 @@ async def scores(ctx):
# absolute_scores = defaultdict(decimal.Decimal) # absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict) model_scores = defaultdict(dict)
runtimes = defaultdict(list)
incorrect_tags_counters = defaultdict(Counter) incorrect_tags_counters = defaultdict(Counter)
for md5_hash in all_hashes: for md5_hash in all_hashes:
@ -325,7 +332,9 @@ async def scores(ctx):
post = await fetch_post(ctx, md5_hash) post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split()) danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators: for interrogator in interrogators:
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash)) post_data = await interrogator.fetch(ctx, md5_hash)
runtimes[interrogator.model_id].append(post_data.time_taken)
interrogator_tags = set(post_data.tags)
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags) tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
for tag in incorrect_tags: for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1 incorrect_tags_counters[interrogator.model_id][tag] += 1
@ -352,7 +361,8 @@ async def scores(ctx):
key=lambda model: normalized_scores[model], key=lambda model: normalized_scores[model],
reverse=True, reverse=True,
): ):
print(model, normalized_scores[model]) average_runtime = sum(runtimes[model]) / len(runtimes[model])
print(model, normalized_scores[model], "runtime", average_runtime, "sec")
if os.getenv("SHOWOFF", "0") == "1": if os.getenv("SHOWOFF", "0") == "1":
print("[", end="") print("[", end="")