expose average runtime to report
This commit is contained in:
parent
4ab35b285e
commit
f53ae5779a
1 changed file with 16 additions and 6 deletions
22
main.py
22
main.py
|
@ -30,6 +30,12 @@ DEFAULTS = [
|
|||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class InterrogatorPost:
|
||||
tags: List[str]
|
||||
time_taken: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class Interrogator:
|
||||
model_id: str
|
||||
|
@ -38,14 +44,14 @@ class Interrogator:
|
|||
def _process(self, lst):
|
||||
return lst
|
||||
|
||||
async def fetch_tags(self, ctx, md5_hash):
|
||||
async def fetch(self, ctx, md5_hash):
|
||||
rows = await ctx.db.execute_fetchall(
|
||||
"select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
|
||||
"select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
|
||||
(md5_hash, self.model_id),
|
||||
)
|
||||
assert len(rows) == 1 # run 'fight' mode, there's missing posts
|
||||
tag_string = rows[0][0]
|
||||
return self._process(tag_string.split())
|
||||
tag_string, time_taken = rows[0][0], rows[0][1]
|
||||
return InterrogatorPost(self._process(tag_string.split()), time_taken)
|
||||
|
||||
|
||||
class DDInterrogator(Interrogator):
|
||||
|
@ -318,6 +324,7 @@ async def scores(ctx):
|
|||
|
||||
# absolute_scores = defaultdict(decimal.Decimal)
|
||||
model_scores = defaultdict(dict)
|
||||
runtimes = defaultdict(list)
|
||||
incorrect_tags_counters = defaultdict(Counter)
|
||||
|
||||
for md5_hash in all_hashes:
|
||||
|
@ -325,7 +332,9 @@ async def scores(ctx):
|
|||
post = await fetch_post(ctx, md5_hash)
|
||||
danbooru_tags = set(post["tag_string"].split())
|
||||
for interrogator in interrogators:
|
||||
interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
|
||||
post_data = await interrogator.fetch(ctx, md5_hash)
|
||||
runtimes[interrogator.model_id].append(post_data.time_taken)
|
||||
interrogator_tags = set(post_data.tags)
|
||||
tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
|
||||
for tag in incorrect_tags:
|
||||
incorrect_tags_counters[interrogator.model_id][tag] += 1
|
||||
|
@ -352,7 +361,8 @@ async def scores(ctx):
|
|||
key=lambda model: normalized_scores[model],
|
||||
reverse=True,
|
||||
):
|
||||
print(model, normalized_scores[model])
|
||||
average_runtime = sum(runtimes[model]) / len(runtimes[model])
|
||||
print(model, normalized_scores[model], "runtime", average_runtime, "sec")
|
||||
if os.getenv("SHOWOFF", "0") == "1":
|
||||
print("[", end="")
|
||||
|
||||
|
|
Loading…
Reference in a new issue