avans06 commited on
Commit
8384099
·
1 Parent(s): c8d7615

Add progress display during the predict phase

Browse files
Files changed (1) hide show
  1. app.py +88 -6
app.py CHANGED
@@ -12,6 +12,7 @@ import tempfile
12
  import zipfile
13
  import re
14
  import ast
 
15
  from datetime import datetime
16
  from collections import defaultdict
17
  from classifyTags import classify_tags
@@ -111,6 +112,52 @@ def mcut_threshold(probs):
111
  thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
112
  return thresh
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  class Llama3Reorganize:
115
  def __init__(
116
  self,
@@ -355,9 +402,21 @@ class Predictor:
355
  additional_tags_prepend,
356
  additional_tags_append,
357
  tag_results,
 
358
  ):
359
- print(f"Predict load model: {model_repo}, gallery length: {len(gallery)}")
 
 
 
 
 
 
 
360
  self.load_model(model_repo)
 
 
 
 
361
  # Result
362
  txt_infos = []
363
  output_dir = tempfile.mkdtemp()
@@ -372,6 +431,11 @@ class Predictor:
372
  if llama3_reorganize_model_repo:
373
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
374
  llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
 
 
 
 
 
375
 
376
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
377
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
@@ -394,7 +458,7 @@ class Predictor:
394
 
395
  input_name = self.model.get_inputs()[0].name
396
  label_name = self.model.get_outputs()[0].name
397
- print(f"Gallery {idx}: Starting run wd model...")
398
  preds = self.model.run([label_name], {input_name: image})[0]
399
 
400
  labels = list(zip(self.tag_names, preds[0].astype(float)))
@@ -444,6 +508,10 @@ class Predictor:
444
  sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
445
 
446
  classified_tags, unclassified_tags = classify_tags(sorted_general_list)
 
 
 
 
447
 
448
  if llama3_reorganize_model_repo:
449
  print(f"Starting reorganize with llama3...")
@@ -453,11 +521,15 @@ class Predictor:
453
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
454
  sorted_general_strings += "," + reorganize_strings
455
 
 
 
 
 
456
  txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
457
  txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
458
 
459
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
460
-
461
  except Exception as e:
462
  print(traceback.format_exc())
463
  print("Error predict: " + str(e))
@@ -475,7 +547,9 @@ class Predictor:
475
  if llama3_reorganize_model_repo:
476
  llama3_reorganize.release_vram()
477
  del llama3_reorganize
478
-
 
 
479
  print("Predict is complete.")
480
 
481
  return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
@@ -524,6 +598,14 @@ def remove_image_from_gallery(gallery: list, selected_image: str):
524
 
525
 
526
  def main():
 
 
 
 
 
 
 
 
527
  args = parse_args()
528
 
529
  predictor = Predictor()
@@ -550,7 +632,7 @@ def main():
550
  META_LLAMA_3_8B_REPO,
551
  ]
552
 
553
- with gr.Blocks(title=TITLE) as demo:
554
  gr.Markdown(
555
  value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
556
  )
@@ -561,10 +643,10 @@ def main():
561
  with gr.Column(variant="panel"):
562
  # Create an Image component for uploading images
563
  image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
564
- gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
565
  with gr.Row():
566
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
567
  remove_button = gr.Button("Remove Selected Image", size="sm")
 
568
 
569
  model_repo = gr.Dropdown(
570
  dropdown_list,
 
12
  import zipfile
13
  import re
14
  import ast
15
+ import time
16
  from datetime import datetime
17
  from collections import defaultdict
18
  from classifyTags import classify_tags
 
112
  thresh = (sorted_probs[t] + sorted_probs[t + 1]) / 2
113
  return thresh
114
 
115
+ class Timer:
116
+ def __init__(self):
117
+ self.start_time = time.perf_counter() # Record the start time
118
+ self.checkpoints = [("Start", self.start_time)] # Store checkpoints
119
+
120
+ def checkpoint(self, label="Checkpoint"):
121
+ """Record a checkpoint with a given label."""
122
+ now = time.perf_counter()
123
+ self.checkpoints.append((label, now))
124
+
125
+ def report(self, is_clear_checkpoints = True):
126
+ # Determine the max label width for alignment
127
+ max_label_length = max(len(label) for label, _ in self.checkpoints)
128
+
129
+ prev_time = self.checkpoints[0][1]
130
+ for label, curr_time in self.checkpoints[1:]:
131
+ elapsed = curr_time - prev_time
132
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
133
+ prev_time = curr_time
134
+
135
+ if is_clear_checkpoints:
136
+ self.checkpoints.clear()
137
+ self.checkpoint() # Store checkpoints
138
+
139
+ def report_all(self):
140
+ """Print all recorded checkpoints and total execution time with aligned formatting."""
141
+ print("\n> Execution Time Report:")
142
+
143
+ # Determine the max label width for alignment
144
+ max_label_length = max(len(label) for label, _ in self.checkpoints) if len(self.checkpoints) > 0 else 0
145
+
146
+ prev_time = self.start_time
147
+ for label, curr_time in self.checkpoints[1:]:
148
+ elapsed = curr_time - prev_time
149
+ print(f"{label.ljust(max_label_length)}: {elapsed:.3f} seconds")
150
+ prev_time = curr_time
151
+
152
+ total_time = self.checkpoints[-1][1] - self.start_time
153
+ print(f"{'Total Execution Time'.ljust(max_label_length)}: {total_time:.3f} seconds\n")
154
+
155
+ self.checkpoints.clear()
156
+
157
+ def restart(self):
158
+ self.start_time = time.perf_counter() # Record the start time
159
+ self.checkpoints = [("Start", self.start_time)] # Store checkpoints
160
+
161
  class Llama3Reorganize:
162
  def __init__(
163
  self,
 
402
  additional_tags_prepend,
403
  additional_tags_append,
404
  tag_results,
405
+ progress=gr.Progress()
406
  ):
407
+ gallery_len = len(gallery)
408
+ print(f"Predict load model: {model_repo}, gallery length: {gallery_len}")
409
+
410
+ timer = Timer() # Create a timer
411
+ progressRatio = 0.5 if llama3_reorganize_model_repo else 1
412
+ progressTotal = gallery_len + 1
413
+ current_progress = 0
414
+
415
  self.load_model(model_repo)
416
+ current_progress += progressRatio/progressTotal;
417
+ progress(current_progress, desc="Initialize wd model finished")
418
+ timer.checkpoint(f"Initialize wd model")
419
+
420
  # Result
421
  txt_infos = []
422
  output_dir = tempfile.mkdtemp()
 
431
  if llama3_reorganize_model_repo:
432
  print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
433
  llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
434
+ current_progress += progressRatio/progressTotal;
435
+ progress(current_progress, desc="Initialize llama3 model finished")
436
+ timer.checkpoint(f"Initialize llama3 model")
437
+
438
+ timer.report()
439
 
440
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
441
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
 
458
 
459
  input_name = self.model.get_inputs()[0].name
460
  label_name = self.model.get_outputs()[0].name
461
+ print(f"Gallery {idx:02d}: Starting run wd model...")
462
  preds = self.model.run([label_name], {input_name: image})[0]
463
 
464
  labels = list(zip(self.tag_names, preds[0].astype(float)))
 
508
  sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + sorted_general_list).replace("(", "\(").replace(")", "\)")
509
 
510
  classified_tags, unclassified_tags = classify_tags(sorted_general_list)
511
+
512
+ current_progress += progressRatio/progressTotal;
513
+ progress(current_progress, desc=f"image{idx:02d}, predict finished")
514
+ timer.checkpoint(f"image{idx:02d}, predict finished")
515
 
516
  if llama3_reorganize_model_repo:
517
  print(f"Starting reorganize with llama3...")
 
521
  reorganize_strings = re.sub(r",,+", ",", reorganize_strings)
522
  sorted_general_strings += "," + reorganize_strings
523
 
524
+ current_progress += progressRatio/progressTotal;
525
+ progress(current_progress, desc=f"image{idx:02d}, llama3 reorganize finished")
526
+ timer.checkpoint(f"image{idx:02d}, llama3 reorganize finished")
527
+
528
  txt_file = self.create_file(sorted_general_strings, output_dir, image_name + ".txt")
529
  txt_infos.append({"path":txt_file, "name": image_name + ".txt"})
530
 
531
  tag_results[image_path] = { "strings": sorted_general_strings, "classified_tags": classified_tags, "rating": rating, "character_res": character_res, "general_res": general_res, "unclassified_tags": unclassified_tags }
532
+ timer.report()
533
  except Exception as e:
534
  print(traceback.format_exc())
535
  print("Error predict: " + str(e))
 
547
  if llama3_reorganize_model_repo:
548
  llama3_reorganize.release_vram()
549
  del llama3_reorganize
550
+
551
+ progress(1, desc=f"Predict completed")
552
+ timer.report_all() # Print all recorded times
553
  print("Predict is complete.")
554
 
555
  return download, sorted_general_strings, classified_tags, rating, character_res, general_res, unclassified_tags, tag_results
 
598
 
599
 
600
  def main():
601
+ # Custom CSS to set the height of the gr.Dropdown menu
602
+ css = """
603
+ div.progress-level div.progress-level-inner {
604
+ text-align: left !important;
605
+ width: 55.5% !important;
606
+ }
607
+ """
608
+
609
  args = parse_args()
610
 
611
  predictor = Predictor()
 
632
  META_LLAMA_3_8B_REPO,
633
  ]
634
 
635
+ with gr.Blocks(title=TITLE, css = css) as demo:
636
  gr.Markdown(
637
  value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
638
  )
 
643
  with gr.Column(variant="panel"):
644
  # Create an Image component for uploading images
645
  image_input = gr.Image(label="Upload an Image or clicking paste from clipboard button", type="filepath", sources=["upload", "clipboard"], height=150)
 
646
  with gr.Row():
647
  upload_button = gr.UploadButton("Upload multiple images", file_types=["image"], file_count="multiple", size="sm")
648
  remove_button = gr.Button("Remove Selected Image", size="sm")
649
+ gallery = gr.Gallery(columns=5, rows=5, show_share_button=False, interactive=True, height="500px", label="Gallery that displaying a grid of images")
650
 
651
  model_repo = gr.Dropdown(
652
  dropdown_list,