Spaces:
Running
Running
Add progress display during the predict phase
Browse files
app.py
CHANGED
@@ -12,6 +12,7 @@ import tempfile
|
|
12 |
import zipfile
|
13 |
import re
|
14 |
import ast
|
|
|
15 |
from datetime import datetime
|
16 |
from collections import defaultdict
|
17 |
from classifyTags import classify_tags
|
@@ -111,6 +112,52 @@ def mcut_threshold(probs):
|
|
111 |
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
|
112 |
return thresh
|
113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
114 |
class Llama3Reorganize:
|
115 |
def __init__(
|
116 |
self,
|
@@ -355,9 +402,21 @@ class Predictor:
|
|
355 |
additional_tags_prepend,
|
356 |
additional_tags_append,
|
357 |
tag_results,
|
|
|
358 |
):
|
359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
360 |
self.load_model(model_repo)
|
|
|
|
|
|
|
|
|
361 |
# Result
|
362 |
txt_infos = []
|
363 |
output_dir = tempfile.mkdtemp()
|
@@ -372,6 +431,11 @@ class Predictor:
|
|
372 |
if llama3_reorganize_model_repo:
|
373 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
374 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
|
|
|
|
|
|
|
|
|
|
375 |
|
376 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
377 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
@@ -394,7 +458,7 @@ class Predictor:
|
|
394 |
|
395 |
input_name = self.model.get_inputs()[0].name
|
396 |
label_name = self.model.get_outputs()[0].name
|
397 |
-
print(f"Gallery {idx}: Starting run wd model...")
|
398 |
preds = self.model.run([label_name], {input_name: image})[0]
|
399 |
|
400 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
@@ -444,6 +508,10 @@ class Predictor:
|
|
444 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
|
445 |
|
446 |
classified_tags, unclassified_tags = classify_tags(sorted_general_list)
|
|
|
|
|
|
|
|
|
447 |
|
448 |
if llama3_reorganize_model_repo:
|
449 |
print(f"Starting reorganize with llama3...")
|
@@ -453,11 +521,15 @@ class Predictor:
|
|
453 |
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
454 |
sorted_general_strings += "," + reorganize_strings
|
455 |
|
|
|
|
|
|
|
|
|
456 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
457 |
txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
|
458 |
|
459 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
460 |
-
|
461 |
except Exception as e:
|
462 |
print(traceback.format_exc())
|
463 |
print("Error predict: " + str(e))
|
@@ -475,7 +547,9 @@ class Predictor:
|
|
475 |
if llama3_reorganize_model_repo:
|
476 |
llama3_reorganize.release_vram()
|
477 |
del llama3_reorganize
|
478 |
-
|
|
|
|
|
479 |
print("Predict is complete.")
|
480 |
|
481 |
return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
@@ -524,6 +598,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
|
|
524 |
|
525 |
|
526 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
527 |
args = parse_args()
|
528 |
|
529 |
predictor = Predictor()
|
@@ -550,7 +632,7 @@ def main():
|
|
550 |
META_LLAMA_3_8B_REPO,
|
551 |
]
|
552 |
|
553 |
-
with gr.Blocks(title=TITLE) as demo:
|
554 |
gr.Markdown(
|
555 |
value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
|
556 |
)
|
@@ -561,10 +643,10 @@ def main():
|
|
561 |
with gr.Column(variant="panel"):
|
562 |
# Create an Image component for uploading images
|
563 |
image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
|
564 |
-
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
565 |
with gr.Row():
|
566 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
567 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
|
|
568 |
|
569 |
model_repo = gr.Dropdown(
|
570 |
dropdown_list,
|
|
|
12 |
import zipfile
|
13 |
import re
|
14 |
import ast
|
15 |
+
import time
|
16 |
from datetime import datetime
|
17 |
from collections import defaultdict
|
18 |
from classifyTags import classify_tags
|
|
|
112 |
thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
|
113 |
return thresh
|
114 |
|
115 |
+
class Timer:
|
116 |
+
def __init__(self):
|
117 |
+
self.start_time = time.perf_counter() # Record the start time
|
118 |
+
self.checkpoints = [("Start", self.start_time)] # Store checkpoints
|
119 |
+
|
120 |
+
def checkpoint(self, label="Checkpoint"):
|
121 |
+
"""Record a checkpoint with a given label."""
|
122 |
+
now = time.perf_counter()
|
123 |
+
self.checkpoints.append((label, now))
|
124 |
+
|
125 |
+
def report(self, is_clear_checkpoints = True):
|
126 |
+
# Determine the max label width for alignment
|
127 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints)
|
128 |
+
|
129 |
+
prev_time = self.checkpoints[0][1]
|
130 |
+
for label, curr_time in self.checkpoints[1:]:
|
131 |
+
elapsed = curr_time - prev_time
|
132 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
133 |
+
prev_time = curr_time
|
134 |
+
|
135 |
+
if is_clear_checkpoints:
|
136 |
+
self.checkpoints.clear()
|
137 |
+
self.checkpoint() # Store checkpoints
|
138 |
+
|
139 |
+
def report_all(self):
|
140 |
+
"""Print all recorded checkpoints and total execution time with aligned formatting."""
|
141 |
+
print("\n> Execution Time Report:")
|
142 |
+
|
143 |
+
# Determine the max label width for alignment
|
144 |
+
max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
|
145 |
+
|
146 |
+
prev_time = self.start_time
|
147 |
+
for label, curr_time in self.checkpoints[1:]:
|
148 |
+
elapsed = curr_time - prev_time
|
149 |
+
print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
|
150 |
+
prev_time = curr_time
|
151 |
+
|
152 |
+
total_time = self.checkpoints[-1][1] - self.start_time
|
153 |
+
print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
|
154 |
+
|
155 |
+
self.checkpoints.clear()
|
156 |
+
|
157 |
+
def restart(self):
|
158 |
+
self.start_time = time.perf_counter() # Record the start time
|
159 |
+
self.checkpoints = [("Start", self.start_time)] # Store checkpoints
|
160 |
+
|
161 |
class Llama3Reorganize:
|
162 |
def __init__(
|
163 |
self,
|
|
|
402 |
additional_tags_prepend,
|
403 |
additional_tags_append,
|
404 |
tag_results,
|
405 |
+
progress=gr.Progress()
|
406 |
):
|
407 |
+
gallery_len = len(gallery)
|
408 |
+
print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
|
409 |
+
|
410 |
+
timer = Timer() # Create a timer
|
411 |
+
progressRatio = 0.5 if llama3_reorganize_model_repo else 1
|
412 |
+
progressTotal = gallery_len + 1
|
413 |
+
current_progress = 0
|
414 |
+
|
415 |
self.load_model(model_repo)
|
416 |
+
current_progress += progressRatio/progressTotal;
|
417 |
+
progress(current_progress, desc="Initialize wd model finished")
|
418 |
+
timer.checkpoint(f"Initialize wd model")
|
419 |
+
|
420 |
# Result
|
421 |
txt_infos = []
|
422 |
output_dir = tempfile.mkdtemp()
|
|
|
431 |
if llama3_reorganize_model_repo:
|
432 |
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
433 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
434 |
+
current_progress += progressRatio/progressTotal;
|
435 |
+
progress(current_progress, desc="Initialize llama3 model finished")
|
436 |
+
timer.checkpoint(f"Initialize llama3 model")
|
437 |
+
|
438 |
+
timer.report()
|
439 |
|
440 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
441 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
|
|
458 |
|
459 |
input_name = self.model.get_inputs()[0].name
|
460 |
label_name = self.model.get_outputs()[0].name
|
461 |
+
print(f"Gallery {idx:02d}: Starting run wd model...")
|
462 |
preds = self.model.run([label_name], {input_name: image})[0]
|
463 |
|
464 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
|
508 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
|
509 |
|
510 |
classified_tags, unclassified_tags = classify_tags(sorted_general_list)
|
511 |
+
|
512 |
+
current_progress += progressRatio/progressTotal;
|
513 |
+
progress(current_progress, desc=f"image{idx:02d}, predict finished")
|
514 |
+
timer.checkpoint(f"image{idx:02d}, predict finished")
|
515 |
|
516 |
if llama3_reorganize_model_repo:
|
517 |
print(f"Starting reorganize with llama3...")
|
|
|
521 |
reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
|
522 |
sorted_general_strings += "," + reorganize_strings
|
523 |
|
524 |
+
current_progress += progressRatio/progressTotal;
|
525 |
+
progress(current_progress, desc=f"image{idx:02d}, llama3 reorganize finished")
|
526 |
+
timer.checkpoint(f"image{idx:02d}, llama3 reorganize finished")
|
527 |
+
|
528 |
txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
|
529 |
txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
|
530 |
|
531 |
tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
|
532 |
+
timer.report()
|
533 |
except Exception as e:
|
534 |
print(traceback.format_exc())
|
535 |
print("Error predict: " + str(e))
|
|
|
547 |
if llama3_reorganize_model_repo:
|
548 |
llama3_reorganize.release_vram()
|
549 |
del llama3_reorganize
|
550 |
+
|
551 |
+
progress(1, desc=f"Predict completed")
|
552 |
+
timer.report_all() # Print all recorded times
|
553 |
print("Predict is complete.")
|
554 |
|
555 |
return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
|
|
|
598 |
|
599 |
|
600 |
def main():
|
601 |
+
# Custom CSS to set the height of the gr.Dropdown menu
|
602 |
+
css = """
|
603 |
+
div.progress-level div.progress-level-inner {
|
604 |
+
text-align: left !important;
|
605 |
+
width: 55.5% !important;
|
606 |
+
}
|
607 |
+
"""
|
608 |
+
|
609 |
args = parse_args()
|
610 |
|
611 |
predictor = Predictor()
|
|
|
632 |
META_LLAMA_3_8B_REPO,
|
633 |
]
|
634 |
|
635 |
+
with gr.Blocks(title=TITLE, css = css) as demo:
|
636 |
gr.Markdown(
|
637 |
value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
|
638 |
)
|
|
|
643 |
with gr.Column(variant="panel"):
|
644 |
# Create an Image component for uploading images
|
645 |
image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
|
|
|
646 |
with gr.Row():
|
647 |
upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
|
648 |
remove_button = gr.Button("Remove Selected Image", size="sm")
|
649 |
+
gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
|
650 |
|
651 |
model_repo = gr.Dropdown(
|
652 |
dropdown_list,
|