avans06 commited on
Commit
2f1b118
·
1 Parent(s): d8681b3

Change the global tag_results variable to use Gradio's State for execution.

Browse files
Files changed (1) hide show
  1. app.py +16 -11
app.py CHANGED
@@ -69,8 +69,6 @@ kaomojis = [
69
  "||_||",
70
  ]
71
 
72
- tag_results = {}
73
-
74
 
75
  def parse_args() -> argparse.Namespace:
76
  parser = argparse.ArgumentParser()
@@ -350,7 +348,9 @@ class Predictor:
350
  llama3_reorganize_model_repo,
351
  additional_tags_prepend,
352
  additional_tags_append,
 
353
  ):
 
354
  self.load_model(model_repo)
355
  # Result
356
  txt_infos = []
@@ -363,16 +363,15 @@ class Predictor:
363
  character_res = None
364
  general_res = None
365
 
366
- tag_results.clear()
367
-
368
  if llama3_reorganize_model_repo:
 
369
  llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
370
 
371
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
372
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
373
  if prepend_list and append_list:
374
  append_list = [item for item in append_list if item not in prepend_list]
375
-
376
  for idx, value in enumerate(gallery):
377
  try:
378
  image_path = value[0]
@@ -382,6 +381,7 @@ class Predictor:
382
 
383
  input_name = self.model.get_inputs()[0].name
384
  label_name = self.model.get_outputs()[0].name
 
385
  preds = self.model.run([label_name], {input_name: image})[0]
386
 
387
  labels = list(zip(self.tag_names, preds[0].astype(float)))
@@ -429,6 +429,7 @@ class Predictor:
429
  sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
430
 
431
  if llama3_reorganize_model_repo:
 
432
  reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
433
  reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
434
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
@@ -458,9 +459,11 @@ class Predictor:
458
  llama3_reorganize.release_vram()
459
  del llama3_reorganize
460
 
461
- return download, sorted_general_strings, rating, character_res, general_res
 
 
462
 
463
- def get_selection_from_gallery(gallery: list, selected_state: gr.SelectData):
464
  if not selected_state:
465
  return selected_state
466
 
@@ -627,14 +630,15 @@ def main():
627
  general_res,
628
  ]
629
  )
630
-
 
631
  # Define the event listener to add the uploaded image to the gallery
632
  image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
633
  # When the upload button is clicked, add the new images to the gallery
634
  upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
635
  # Event to update the selected image when an image is clicked in the gallery
636
  selected_image = gr.Textbox(label="Selected Image", visible=False)
637
- gallery.select(get_selection_from_gallery, inputs=gallery, outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
638
  # Event to remove a selected image from the gallery
639
  remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
640
 
@@ -651,8 +655,9 @@ def main():
651
  llama3_reorganize_model_repo,
652
  additional_tags_prepend,
653
  additional_tags_append,
 
654
  ],
655
- outputs=[download_file, sorted_general_strings, rating, character_res, general_res],
656
  )
657
 
658
  gr.Examples(
@@ -667,7 +672,7 @@ def main():
667
  ],
668
  )
669
 
670
- demo.queue(max_size=10)
671
  demo.launch(inbrowser=True)
672
 
673
 
 
69
  "||_||",
70
  ]
71
 
 
 
72
 
73
  def parse_args() -> argparse.Namespace:
74
  parser = argparse.ArgumentParser()
 
348
  llama3_reorganize_model_repo,
349
  additional_tags_prepend,
350
  additional_tags_append,
351
+ tag_results,
352
  ):
353
+ print(f"Predict load model: {model_repo}, gallery length: {len(gallery)}")
354
  self.load_model(model_repo)
355
  # Result
356
  txt_infos = []
 
363
  character_res = None
364
  general_res = None
365
 
 
 
366
  if llama3_reorganize_model_repo:
367
+ print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
368
  llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
369
 
370
  prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
371
  append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
372
  if prepend_list and append_list:
373
  append_list = [item for item in append_list if item not in prepend_list]
374
+
375
  for idx, value in enumerate(gallery):
376
  try:
377
  image_path = value[0]
 
381
 
382
  input_name = self.model.get_inputs()[0].name
383
  label_name = self.model.get_outputs()[0].name
384
+ print(f"Gallery {idx}: Starting run wd model...")
385
  preds = self.model.run([label_name], {input_name: image})[0]
386
 
387
  labels = list(zip(self.tag_names, preds[0].astype(float)))
 
429
  sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
430
 
431
  if llama3_reorganize_model_repo:
432
+ print(f"Starting reorganize with llama3...")
433
  reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
434
  reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
435
  reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
 
459
  llama3_reorganize.release_vram()
460
  del llama3_reorganize
461
 
462
+ print("Predict is complete.")
463
+
464
+ return download, sorted_general_strings, rating, character_res, general_res, tag_results
465
 
466
+ def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
467
  if not selected_state:
468
  return selected_state
469
 
 
630
  general_res,
631
  ]
632
  )
633
+
634
+ tag_results = gr.State({})
635
  # Define the event listener to add the uploaded image to the gallery
636
  image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
637
  # When the upload button is clicked, add the new images to the gallery
638
  upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
639
  # Event to update the selected image when an image is clicked in the gallery
640
  selected_image = gr.Textbox(label="Selected Image", visible=False)
641
+ gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
642
  # Event to remove a selected image from the gallery
643
  remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
644
 
 
655
  llama3_reorganize_model_repo,
656
  additional_tags_prepend,
657
  additional_tags_append,
658
+ tag_results,
659
  ],
660
+ outputs=[download_file, sorted_general_strings, rating, character_res, general_res, tag_results,],
661
  )
662
 
663
  gr.Examples(
 
672
  ],
673
  )
674
 
675
+ demo.queue(max_size=2)
676
  demo.launch(inbrowser=True)
677
 
678