Compare commits: 990b69c602 ... cc7af237b4 (5 commits)

Commits (SHA1):
- cc7af237b4
- c721e5ee31
- a73a3f26ce
- 2550320f9d
- 90695c0310

2 changed files with 178 additions and 70 deletions

main.py (234 changed lines)
@@ -37,6 +37,8 @@ DEFAULTS = [
     # broken model: "mld-tresnetd.6-30000",
 ]
 
+RATING = ["general", "explicit", "questionable", "sensitive", "safe"]
+
 
 @dataclass
 class InterrogatorPost:
@@ -48,6 +50,8 @@ class InterrogatorPost:
 class Interrogator:
     model_id: str
     address: str
+    threshold: float = 0.55
+    _fucked_rating: bool = False
 
     def _process(self, lst):
         return lst
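Note: the two added fields give every interrogator a per-model threshold (default 0.55) and a rating-quirk flag. A minimal sketch, assuming Interrogator is a @dataclass like the neighboring InterrogatorPost and Config classes (the decorator sits just outside this hunk), of how the positional constructions used further down bind to these fields:

    from dataclasses import dataclass

    @dataclass
    class Interrogator:
        model_id: str
        address: str
        threshold: float = 0.55
        _fucked_rating: bool = False

    # e.g. SDInterrogator("camie-tagger-v1", camie_address, 0.5, True)
    # binds threshold=0.5 and _fucked_rating=True positionally.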
@@ -68,11 +72,16 @@ class DDInterrogator(Interrogator):
         new_lst = []
         for tag in lst:
             if tag.startswith("rating:"):
-                continue
+                original_danbooru_tag = tag.split(":")[1]
             else:
                 original_danbooru_tag = tag
 
+            if original_danbooru_tag == "safe":
+                original_danbooru_tag = "general"
+                continue
 
+            if original_danbooru_tag in RATING:
+                continue
 
             new_lst.append(original_danbooru_tag)
         return new_lst
@@ -80,7 +89,7 @@ class DDInterrogator(Interrogator):
         async with ctx.session.post(
             f"{self.address}/",
             params={
-                "threshold": "0.7",
+                "threshold": "0.55",
             },
             headers={"Authorization": "Bearer sex"},
             data={"file": path.open("rb")},
@@ -92,19 +101,32 @@ class DDInterrogator(Interrogator):
 
 
 class SDInterrogator(Interrogator):
+    def _process(self, lst):
+        new_lst = []
+        for tag in lst:
+            if tag.startswith("rating_"):
+                continue
+            elif tag in RATING:
+                continue
+            else:
+                original_danbooru_tag = tag
+            new_lst.append(original_danbooru_tag)
+        return new_lst
+
     async def interrogate(self, ctx, path):
         async with aiofiles.open(path, "rb") as fd:
             as_base64 = base64.b64encode(await fd.read()).decode("utf-8")
 
+        url = f"{self.address}/tagger/v1/interrogate"
         async with ctx.session.post(
-            f"{self.address}/tagger/v1/interrogate",
+            url,
             json={
                 "model": self.model_id,
-                "threshold": 0.7,
+                "threshold": self.threshold,
                 "image": as_base64,
             },
         ) as resp:
-            log.info("got %d", resp.status)
+            log.info("%s got %d from %s", path, resp.status, url)
             assert resp.status == 200
             data = await resp.json()
             tags = []
@@ -123,11 +145,27 @@ class SDInterrogator(Interrogator):
         return " ".join(upstream_tags)
 
+
+def tag_string_for(post: dict) -> str:
+    return (
+        post["tag_string_general"]
+        + " "
+        + post["tag_string_copyright"]
+        + " "
+        + post["tag_string_character"]
+    )
+
 
 class ControlInterrogator(Interrogator):
+    async def fetch(self, ctx, path):
+        md5_hash = Path(path).stem
+        post = await fetch_post(ctx, md5_hash)
+        tag_string = tag_string_for(post)
+        return InterrogatorPost(tag_string.split(), 0)
+
     async def interrogate(self, ctx, path):
         md5_hash = Path(path).stem
         post = await fetch_post(ctx, md5_hash)
-        return post["tag_string"]
+        return tag_string_for(post)
 
 
 @dataclass
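Note: tag_string_for just joins the three Danbooru tag fields with spaces; a quick sketch with made-up field values:

    post = {
        "tag_string_general": "1girl smile",
        "tag_string_copyright": "original",
        "tag_string_character": "some_character",
    }
    assert tag_string_for(post) == "1girl smile original some_character"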
@@ -136,6 +174,8 @@ class Config:
     dd_address: str
     dd_model_name: str
     sd_webui_extras: Dict[str, str]
+    camie_address: str
+    joytag_address: str
     sd_webui_models: List[str] = field(default_factory=lambda: list(DEFAULTS))
 
     @property
@@ -149,7 +189,11 @@ class Config:
                 SDInterrogator(sd_interrogator, url)
                 for sd_interrogator, url in self.sd_webui_extras.items()
             ]
-            + [DDInterrogator(self.dd_model_name, self.dd_address)]
+            + [
+                DDInterrogator(self.dd_model_name, self.dd_address),
+                SDInterrogator("camie-tagger-v1", self.camie_address, 0.5, True),
+                SDInterrogator("joytag-v1", self.joytag_address, 0.5),
+            ]
             + [ControlInterrogator("control", None)]
         )
 
@@ -183,7 +227,13 @@ class Danbooru(Booru):
     async def posts(self, tag_query: str, limit, page: int):
         log.info("%s: submit %r", self.title, tag_query)
         async with self.limiter:
-            log.info("%s: submit upstream %r", tag_query)
+            log.info(
+                "%s: submit upstream query=%r limit=%r page=%r",
+                self.title,
+                tag_query,
+                limit,
+                page,
+            )
             async with self.session.get(
                 f"{self.base_url}/posts.json",
                 params={"tags": tag_query, "limit": limit, "page": page},
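Note: the expanded call keeps logging's lazy %-style formatting, so the message is only built when the INFO level is enabled; an f-string would format unconditionally:

    log.info("%s: submit upstream query=%r", self.title, tag_query)   # lazy
    log.info(f"{self.title}: submit upstream query={tag_query!r}")    # eager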
@@ -267,19 +317,7 @@ async def fetch_post(ctx, md5) -> Optional[dict]:
         return None
     assert len(rows) == 1
     post = json.loads(rows[0][0])
-    post_rating = post["rating"]
-    match post_rating:
-        case "g":
-            rating_tag = "general"
-        case "s":
-            rating_tag = "sensitive"
-        case "q":
-            rating_tag = "questionable"
-        case "e":
-            rating_tag = "explicit"
-        case _:
-            raise AssertionError("invalid post rating {post_rating!r}")
-    post["tag_string"] = post["tag_string"] + " " + rating_tag
+    post["tag_string"] = tag_string_for(post)
     return post
 
 
@@ -305,6 +343,22 @@ async def insert_interrogated_result(
     await ctx.db.commit()
 
+
+async def process_hash(ctx, interrogator, missing_hash, semaphore, index, total):
+    async with semaphore:
+        log.info("interrogating %r (%d/%d)", missing_hash, index, total)
+        post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
+
+        start_ts = time.monotonic()
+        tag_string = await interrogator.interrogate(ctx, post_filepath)
+        end_ts = time.monotonic()
+        time_taken = round(end_ts - start_ts, 10)
+
+        log.info("took %.5fsec, got %r", time_taken, tag_string)
+        await insert_interrogated_result(
+            ctx, interrogator, missing_hash, tag_string, time_taken
+        )
+
 
 async def fight(ctx):
     interrogators = ctx.config.all_available_models
 
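Note: process_hash is the unit of work for the concurrent rewrite of fight() below: every missing hash gets a task up front, but the shared semaphore admits only three interrogations at a time. A standalone sketch of the same pattern (worker/job names are illustrative):

    import asyncio

    async def worker(job: int, semaphore: asyncio.Semaphore) -> int:
        async with semaphore:          # at most 3 workers inside this block at once
            await asyncio.sleep(1)     # stand-in for interrogator.interrogate()
            return job

    async def main() -> list:
        semaphore = asyncio.Semaphore(3)
        tasks = [worker(j, semaphore) for j in range(10)]
        return await asyncio.gather(*tasks)   # ~4s total instead of ~10s

    asyncio.run(main())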
@@ -312,7 +366,7 @@ async def fight(ctx):
     all_hashes = set(r[0] for r in all_rows)
 
     for interrogator in interrogators:
-        log.info("processing for %r", interrogator)
+        log.info("processing fight for %r", interrogator)
         # calculate set of images we didn't interrogate yet
         interrogated_rows = await ctx.db.execute_fetchall(
             "select md5 from interrogated_posts where model_name = ?",
@@ -323,34 +377,63 @@ async def fight(ctx):
 
         log.info("missing %d hashes", len(missing_hashes))
 
+        semaphore = asyncio.Semaphore(3)
+
+        tasks = []
         for index, missing_hash in enumerate(missing_hashes):
-            log.info(
-                "interrogating %r (%d/%d)", missing_hash, index, len(missing_hashes)
-            )
-            post_filepath = next(DOWNLOADS.glob(f"{missing_hash}*"))
-
-            start_ts = time.monotonic()
-            tag_string = await interrogator.interrogate(ctx, post_filepath)
-            end_ts = time.monotonic()
-            time_taken = round(end_ts - start_ts, 10)
-
-            log.info("took %.5fsec, got %r", time_taken, tag_string)
-            await insert_interrogated_result(
-                ctx, interrogator, missing_hash, tag_string, time_taken
-            )
+            task = process_hash(
+                ctx,
+                interrogator,
+                missing_hash,
+                semaphore,
+                index + 1,
+                len(missing_hashes),
+            )
+            tasks.append(task)
+
+        # Run all tasks concurrently with semaphore limiting to 3 at a time
+        await asyncio.gather(*tasks)
 
 
 def score(
     danbooru_tags: Set[str], interrogator_tags: Set[str]
 ) -> Tuple[decimal.Decimal, Set[str]]:
     tags_in_danbooru = danbooru_tags.intersection(interrogator_tags)
 
+    f1 = None
+    # Handle edge cases
+    if len(danbooru_tags) == 0 and len(interrogator_tags) == 0:
+        f1 = decimal.Decimal("1.0")  # Both empty means perfect match
+
+    if len(danbooru_tags) == 0 or len(interrogator_tags) == 0:
+        f1 = decimal.Decimal("0.0")  # One empty means no match
+
+    # Calculate true positives (tags that appear in both sets)
+    true_positives = decimal.Decimal(len(danbooru_tags.intersection(interrogator_tags)))
+
+    # Calculate precision: TP / (TP + FP)
+    precision = (
+        true_positives / len(interrogator_tags)
+        if len(interrogator_tags) > 0
+        else decimal.Decimal("0.0")
+    )
+
+    # Calculate recall: TP / (TP + FN)
+    recall = (
+        true_positives / len(danbooru_tags)
+        if len(danbooru_tags) > 0
+        else decimal.Decimal("0.0")
+    )
+    print("recall", recall)
+
+    # Handle the case where both precision and recall are 0
+    if f1 is None and precision == 0 and recall == 0:
+        f1 = decimal.Decimal("0.0")
+    else:
+        f1 = decimal.Decimal("2.0") * (precision * recall) / (precision + recall)
+
     tags_not_in_danbooru = interrogator_tags - danbooru_tags
     return (
         round(
             decimal.Decimal(len(tags_in_danbooru) - len(tags_not_in_danbooru))
             / decimal.Decimal(len(danbooru_tags)),
             10,
         ),
+        round(f1, 10),
         tags_not_in_danbooru,
     )
 
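Note: a worked example of the new F1 path. With danbooru_tags = {a, b, c, d} and interrogator_tags = {a, b, e}: TP = 2, precision = 2/3, recall = 2/4, F1 = 2PR/(P+R) = 4/7 ≈ 0.5714285714, and the first returned value is (2 matches - 1 extra) / 4 = 0.25:

    positive, f1, wrong = score({"a", "b", "c", "d"}, {"a", "b", "e"})
    # positive == Decimal("0.25"), f1 == Decimal("0.5714285714"), wrong == {"e"}

One caveat: the empty-set guards assign f1 but don't return early, so an empty danbooru_tags still reaches the 2PR/(P+R) division (0/0) and the final division by len(danbooru_tags).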
@@ -365,9 +448,10 @@ async def scores(ctx):
     model_scores = defaultdict(dict)
     runtimes = defaultdict(list)
     incorrect_tags_counters = defaultdict(Counter)
+    predicted_tags_counter = defaultdict(int)
 
     for md5_hash in all_hashes:
-        log.info("processing for %r", md5_hash)
+        log.info("processing score for %r", md5_hash)
         post = await fetch_post(ctx, md5_hash)
         danbooru_tags = set(post["tag_string"].split())
         for interrogator in interrogators:
@@ -378,9 +462,14 @@ async def scores(ctx):
             for tag in incorrect_tags:
                 incorrect_tags_counters[interrogator.model_id][tag] += 1
 
+            log.info(f"{interrogator.model_id} {tagging_score}")
+            predicted_tags_counter[interrogator.model_id] += len(interrogator_tags)
+            correct_tags = interrogator_tags.intersection(danbooru_tags)
             model_scores[interrogator.model_id][md5_hash] = {
                 "score": tagging_score,
                 "predicted_tags": interrogator_tags,
                 "incorrect_tags": incorrect_tags,
+                "correct_tags": correct_tags,
             }
+
     summed_scores = {
@@ -425,8 +514,19 @@ async def scores(ctx):
             print(data["score"], end=",")
 
         print("]")
-        print("most incorrect tags", incorrect_tags_counters[model].most_common(5))
-
+        total_incorrect = 0
+        for _, c in incorrect_tags_counters[model].most_common(10000000):
+            total_incorrect += c
+        print(
+            "most incorrect tags from",
+            total_incorrect,
+            "incorrect tags",
+            "predicted",
+            predicted_tags_counter[model],
+            "tags",
+        )
+        for t, c in incorrect_tags_counters[model].most_common(7):
+            print("\t", t, c)
     PLOTS = Path.cwd() / "plots"
     PLOTS.mkdir(exist_ok=True)
 
@@ -459,7 +559,12 @@ async def scores(ctx):
     log.info("plotting positive histogram...")
     plot2(PLOTS / "positive_score_histogram.png", normalized_scores, model_scores)
     log.info("plotting error rates...")
-    plot3(PLOTS / "error_rate.png", normalized_scores, model_scores)
+    plot3(
+        PLOTS / "error_rate.png",
+        PLOTS / "score_avg.png",
+        normalized_scores,
+        model_scores,
+    )
 
 
 def plot2(output_path, normalized_scores, model_scores):
@@ -490,12 +595,13 @@ def plot2(output_path, normalized_scores, model_scores):
     pio.write_image(fig, output_path, width=1024, height=800)
 
 
-def plot3(output_path, normalized_scores, model_scores):
+def plot3(output_path, output_score_avg_path, normalized_scores, model_scores):
     data_for_df = {
         "model": [],
-        "errors": [],
-        "rating_errors": [],
-        "practical_errors": [],
+        "score_avg": [],
+        "predicted": [],
+        "correct": [],
+        "incorrect": [],
     }
 
     for model in sorted(
@@ -503,32 +609,34 @@ def plot3(output_path, normalized_scores, model_scores):
         key=lambda model: normalized_scores[model],
         reverse=True,
     ):
-        total_incorrect_tags = 0
-        total_rating_errors = 0
+        total_predicted_tags, total_incorrect_tags, total_correct_tags = 0, 0, 0
         for score_data in model_scores[model].values():
+            total_predicted_tags += len(score_data["predicted_tags"])
             total_incorrect_tags += len(score_data["incorrect_tags"])
-            total_rating_errors += sum(
-                1
-                for rating in ["general", "sensitive", "questionable", "explicit"]
-                if rating in score_data["incorrect_tags"]
-            )
-            practical_absolute_error = total_incorrect_tags - total_rating_errors
+            total_correct_tags += len(score_data["correct_tags"])
 
-        data_for_df["errors"].append(total_incorrect_tags)
-        data_for_df["rating_errors"].append(total_rating_errors)
-        data_for_df["practical_errors"].append(practical_absolute_error)
+        data_for_df["score_avg"].append(normalized_scores[model])
+        data_for_df["predicted"].append(total_predicted_tags)
+        data_for_df["incorrect"].append(total_incorrect_tags)
+        data_for_df["correct"].append(total_correct_tags)
         data_for_df["model"].append(model)
 
     df = pd.DataFrame(data_for_df)
 
     fig = go.Figure(
         data=[
-            go.Bar(name="incorrect tags", x=df.model, y=df.errors),
-            go.Bar(name="incorrect ratings", x=df.model, y=df.rating_errors),
-            go.Bar(name="practical error", x=df.model, y=df.practical_errors),
+            go.Bar(name="predicted tags", x=df.model, y=df.predicted),
+            go.Bar(name="incorrect tags", x=df.model, y=df.incorrect),
+            go.Bar(name="correct tags", x=df.model, y=df.correct),
         ]
     )
     pio.write_image(fig, output_path, width=1024, height=800)
+    fig2 = go.Figure(
+        data=[
+            go.Bar(name="score avg", x=df.model, y=df.score_avg),
+        ]
+    )
+    pio.write_image(fig2, output_score_avg_path, width=1024, height=800)
 
 
 async def realmain(ctx):
requirements.txt

@@ -1,7 +1,7 @@
-aiosqlite==0.20.0
-aiohttp==3.10.0
-aiolimiter>1.1.0<2.0
-aiofiles==24.1.0
-plotly>5.15.0<6.0
-pandas==2.2.2
-kaleido==0.2.1
+aiosqlite
+aiohttp
+aiolimiter>1.1.0
+aiofiles
+plotly>5.15.0
+pandas
+kaleido
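Note on the loosened pins: PEP 508 requires commas between version clauses, so the old aiolimiter>1.1.0<2.0 and plotly>5.15.0<6.0 forms were not valid specifiers to begin with; if the upper bounds are still wanted, the comma form parses:

    aiolimiter>1.1.0,<2.0
    plotly>5.15.0,<6.0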
|
Loading…
Add table
Add a link
Reference in a new issue