expose averaage runtime to report

This commit is contained in:
Luna 2023-06-10 03:22:20 -03:00
parent 4ab35b285e
commit f53ae5779a
1 changed files with 16 additions and 6 deletions

22
main.py
View File

@ -30,6 +30,12 @@ DEFAULTS = [
]
@dataclass
class InterrogatorPost:
tags: List[str]
time_taken: float
@dataclass
class Interrogator:
model_id: str
@ -38,14 +44,14 @@ class Interrogator:
def _process(self, lst):
return lst
async def fetch_tags(self, ctx, md5_hash):
async def fetch(self, ctx, md5_hash):
rows = await ctx.db.execute_fetchall(
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
"select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
(md5_hash, self.model_id),
)
assert len(rows) == 1 # run 'fight' mode, there's missing posts
tag_string = rows[0][0]
return self._process(tag_string.split())
tag_string, time_taken = rows[0][0], rows[0][1]
return InterrogatorPost(self._process(tag_string.split()), time_taken)
class DDInterrogator(Interrogator):
@ -318,6 +324,7 @@ async def scores(ctx):
# absolute_scores = defaultdict(decimal.Decimal)
model_scores = defaultdict(dict)
runtimes = defaultdict(list)
incorrect_tags_counters = defaultdict(Counter)
for md5_hash in all_hashes:
@ -325,7 +332,9 @@ async def scores(ctx):
post = await fetch_post(ctx, md5_hash)
danbooru_tags = set(post["tag_string"].split())
for interrogator in interrogators:
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
post_data = await interrogator.fetch(ctx, md5_hash)
runtimes[interrogator.model_id].append(post_data.time_taken)
interrogator_tags = set(post_data.tags)
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
for tag in incorrect_tags:
incorrect_tags_counters[interrogator.model_id][tag] += 1
@ -352,7 +361,8 @@ async def scores(ctx):
key=lambda model: normalized_scores[model],
reverse=True,
):
print(model, normalized_scores[model])
average_runtime = sum(runtimes[model]) / len(runtimes[model])
print(model, normalized_scores[model], "runtime", average_runtime, "sec")
if os.getenv("SHOWOFF", "0") == "1":
print("[", end="")