expose average runtime to report
parent 4ab35b285e
commit f53ae5779a
main.py | 22 ++++++++++++++++------
1 file changed, 16 insertions(+), 6 deletions(-)
--- a/main.py
+++ b/main.py
@@ -30,6 +30,12 @@ DEFAULTS = [
 ]
 
 
+@dataclass
+class InterrogatorPost:
+    tags: List[str]
+    time_taken: float
+
+
 @dataclass
 class Interrogator:
     model_id: str
@@ -38,14 +44,14 @@ class Interrogator:
     def _process(self, lst):
         return lst
 
-    async def fetch_tags(self, ctx, md5_hash):
+    async def fetch(self, ctx, md5_hash):
         rows = await ctx.db.execute_fetchall(
-            "select output_tag_string from interrogated_posts where md5 = ? and model_name = ?",
+            "select output_tag_string, time_taken from interrogated_posts where md5 = ? and model_name = ?",
             (md5_hash, self.model_id),
         )
         assert len(rows) == 1  # run 'fight' mode, there's missing posts
-        tag_string = rows[0][0]
-        return self._process(tag_string.split())
+        tag_string, time_taken = rows[0][0], rows[0][1]
+        return InterrogatorPost(self._process(tag_string.split()), time_taken)
 
 
 class DDInterrogator(Interrogator):
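
Taken on its own, the new fetch() contract reduces to the sketch below. This is a minimal reproduction, not the repository's code: ctx.db is assumed to be an aiosqlite connection (the execute_fetchall call suggests as much), the interrogated_posts layout is inferred from the query, and the sample row is made up.

import asyncio
from dataclasses import dataclass
from typing import List

import aiosqlite


@dataclass
class InterrogatorPost:
    tags: List[str]
    time_taken: float


async def demo():
    async with aiosqlite.connect(":memory:") as db:
        # Schema inferred from the select above; the real table may differ.
        await db.execute(
            "create table interrogated_posts"
            " (md5 text, model_name text, output_tag_string text, time_taken real)"
        )
        await db.execute(
            "insert into interrogated_posts values (?, ?, ?, ?)",
            ("d41d8cd9", "dd-v3", "1girl solo", 0.8),  # made-up sample row
        )
        rows = await db.execute_fetchall(
            "select output_tag_string, time_taken from interrogated_posts"
            " where md5 = ? and model_name = ?",
            ("d41d8cd9", "dd-v3"),
        )
        assert len(rows) == 1
        tag_string, time_taken = rows[0]
        post = InterrogatorPost(tag_string.split(), time_taken)
        print(post.tags, post.time_taken)  # ['1girl', 'solo'] 0.8


asyncio.run(demo())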
@@ -318,6 +324,7 @@ async def scores(ctx):
 
     # absolute_scores = defaultdict(decimal.Decimal)
     model_scores = defaultdict(dict)
+    runtimes = defaultdict(list)
     incorrect_tags_counters = defaultdict(Counter)
 
     for md5_hash in all_hashes:
@@ -325,7 +332,9 @@ async def scores(ctx):
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
-            interrogator_tags = set(await interrogator.fetch_tags(ctx, md5_hash))
+            post_data = await interrogator.fetch(ctx, md5_hash)
+            runtimes[interrogator.model_id].append(post_data.time_taken)
+            interrogator_tags = set(post_data.tags)
             tagging_score, incorrect_tags = score(danbooru_tags, interrogator_tags)
             for tag in incorrect_tags:
                 incorrect_tags_counters[interrogator.model_id][tag] += 1
@@ -352,7 +361,8 @@ async def scores(ctx):
         key=lambda model: normalized_scores[model],
         reverse=True,
     ):
-        print(model, normalized_scores[model])
+        average_runtime = sum(runtimes[model]) / len(runtimes[model])
+        print(model, normalized_scores[model], "runtime", average_runtime, "sec")
     if os.getenv("SHOWOFF", "0") == "1":
         print("[", end="")
 
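
The reporting side is then plain arithmetic over the collected lists. A small standalone check (model names and runtimes are made up; scores() fills runtimes per model exactly as in the loop above):

from collections import defaultdict

runtimes = defaultdict(list)
runtimes["model-a"] = [0.8, 1.2, 1.0]
runtimes["model-b"] = [2.0, 2.5]

for model, samples in runtimes.items():
    average_runtime = sum(samples) / len(samples)
    print(model, "runtime", average_runtime, "sec")
# model-a runtime 1.0 sec
# model-b runtime 2.25 sec

The division assumes at least one recorded runtime per model, which the assert in fetch() guarantees for every (hash, model) pair that scores() visits.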