Spaces:
Running
Running
Change the global tag_results variable to use Gradio's State for execution.
Browse files
app.py
CHANGED
@@ -69,8 +69,6 @@ kaomojis = [
|
|
69 |
"||_||",
|
70 |
]
|
71 |
|
72 |
-
tag_results = {}
|
73 |
-
|
74 |
|
75 |
def parse_args() -> argparse.Namespace:
|
76 |
parser = argparse.ArgumentParser()
|
@@ -350,7 +348,9 @@ class Predictor:
|
|
350 |
llama3_reorganize_model_repo,
|
351 |
additional_tags_prepend,
|
352 |
additional_tags_append,
|
|
|
353 |
):
|
|
|
354 |
self.load_model(model_repo)
|
355 |
# Result
|
356 |
txt_infos = []
|
@@ -363,16 +363,15 @@ class Predictor:
|
|
363 |
character_res = None
|
364 |
general_res = None
|
365 |
|
366 |
-
tag_results.clear()
|
367 |
-
|
368 |
if llama3_reorganize_model_repo:
|
|
|
369 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
370 |
|
371 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
372 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
373 |
if prepend_list and append_list:
|
374 |
append_list = [item for item in append_list if item not in prepend_list]
|
375 |
-
|
376 |
for idx, value in enumerate(gallery):
|
377 |
try:
|
378 |
image_path = value[0]
|
@@ -382,6 +381,7 @@ class Predictor:
|
|
382 |
|
383 |
input_name = self.model.get_inputs()[0].name
|
384 |
label_name = self.model.get_outputs()[0].name
|
|
|
385 |
preds = self.model.run([label_name], {input_name: image})[0]
|
386 |
|
387 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
@@ -429,6 +429,7 @@ class Predictor:
|
|
429 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
|
430 |
|
431 |
if llama3_reorganize_model_repo:
|
|
|
432 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
433 |
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
434 |
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
@@ -458,9 +459,11 @@ class Predictor:
|
|
458 |
llama3_reorganize.release_vram()
|
459 |
del llama3_reorganize
|
460 |
|
461 |
-
|
|
|
|
|
462 |
|
463 |
-
def get_selection_from_gallery(gallery: list, selected_state: gr.SelectData):
|
464 |
if not selected_state:
|
465 |
return selected_state
|
466 |
|
@@ -627,14 +630,15 @@ def main():
|
|
627 |
general_res,
|
628 |
]
|
629 |
)
|
630 |
-
|
|
|
631 |
# Define the event listener to add the uploaded image to the gallery
|
632 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
633 |
# When the upload button is clicked, add the new images to the gallery
|
634 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
635 |
# Event to update the selected image when an image is clicked in the gallery
|
636 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
637 |
-
gallery.select(get_selection_from_gallery, inputs=gallery, outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
|
638 |
# Event to remove a selected image from the gallery
|
639 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
640 |
|
@@ -651,8 +655,9 @@ def main():
|
|
651 |
llama3_reorganize_model_repo,
|
652 |
additional_tags_prepend,
|
653 |
additional_tags_append,
|
|
|
654 |
],
|
655 |
-
outputs=[download_file, sorted_general_strings, rating, character_res, general_res],
|
656 |
)
|
657 |
|
658 |
gr.Examples(
|
@@ -667,7 +672,7 @@ def main():
|
|
667 |
],
|
668 |
)
|
669 |
|
670 |
-
demo.queue(max_size=
|
671 |
demo.launch(inbrowser=True)
|
672 |
|
673 |
|
|
|
69 |
"||_||",
|
70 |
]
|
71 |
|
|
|
|
|
72 |
|
73 |
def parse_args() -> argparse.Namespace:
|
74 |
parser = argparse.ArgumentParser()
|
|
|
348 |
llama3_reorganize_model_repo,
|
349 |
additional_tags_prepend,
|
350 |
additional_tags_append,
|
351 |
+
tag_results,
|
352 |
):
|
353 |
+
print(f"Predict load model: {model_repo}, gallery length: {len(gallery)}")
|
354 |
self.load_model(model_repo)
|
355 |
# Result
|
356 |
txt_infos = []
|
|
|
363 |
character_res = None
|
364 |
general_res = None
|
365 |
|
|
|
|
|
366 |
if llama3_reorganize_model_repo:
|
367 |
+
print(f"Llama3 reorganize load model {llama3_reorganize_model_repo}")
|
368 |
llama3_reorganize = Llama3Reorganize(llama3_reorganize_model_repo, loadModel=True)
|
369 |
|
370 |
prepend_list = [tag.strip() for tag in additional_tags_prepend.split(",") if tag.strip()]
|
371 |
append_list = [tag.strip() for tag in additional_tags_append.split(",") if tag.strip()]
|
372 |
if prepend_list and append_list:
|
373 |
append_list = [item for item in append_list if item not in prepend_list]
|
374 |
+
|
375 |
for idx, value in enumerate(gallery):
|
376 |
try:
|
377 |
image_path = value[0]
|
|
|
381 |
|
382 |
input_name = self.model.get_inputs()[0].name
|
383 |
label_name = self.model.get_outputs()[0].name
|
384 |
+
print(f"Gallery {idx}: Starting run wd model...")
|
385 |
preds = self.model.run([label_name], {input_name: image})[0]
|
386 |
|
387 |
labels = list(zip(self.tag_names, preds[0].astype(float)))
|
|
|
429 |
sorted_general_strings = ", ".join((character_list if characters_merge_enabled else []) + prepend_list + sorted_general_list + append_list).replace("(", "\(").replace(")", "\)")
|
430 |
|
431 |
if llama3_reorganize_model_repo:
|
432 |
+
print(f"Starting reorganize with llama3...")
|
433 |
reorganize_strings = llama3_reorganize.reorganize(sorted_general_strings)
|
434 |
reorganize_strings = re.sub(r" *Title: *", "", reorganize_strings)
|
435 |
reorganize_strings = re.sub(r"\n+", ",", reorganize_strings)
|
|
|
459 |
llama3_reorganize.release_vram()
|
460 |
del llama3_reorganize
|
461 |
|
462 |
+
print("Predict is complete.")
|
463 |
+
|
464 |
+
return download, sorted_general_strings, rating, character_res, general_res, tag_results
|
465 |
|
466 |
+
def get_selection_from_gallery(gallery: list, tag_results: dict, selected_state: gr.SelectData):
|
467 |
if not selected_state:
|
468 |
return selected_state
|
469 |
|
|
|
630 |
general_res,
|
631 |
]
|
632 |
)
|
633 |
+
|
634 |
+
tag_results = gr.State({})
|
635 |
# Define the event listener to add the uploaded image to the gallery
|
636 |
image_input.change(append_gallery, inputs=[gallery, image_input], outputs=[gallery, image_input])
|
637 |
# When the upload button is clicked, add the new images to the gallery
|
638 |
upload_button.upload(extend_gallery, inputs=[gallery, upload_button], outputs=gallery)
|
639 |
# Event to update the selected image when an image is clicked in the gallery
|
640 |
selected_image = gr.Textbox(label="Selected Image", visible=False)
|
641 |
+
gallery.select(get_selection_from_gallery, inputs=[gallery, tag_results], outputs=[selected_image, sorted_general_strings, rating, character_res, general_res])
|
642 |
# Event to remove a selected image from the gallery
|
643 |
remove_button.click(remove_image_from_gallery, inputs=[gallery, selected_image], outputs=gallery)
|
644 |
|
|
|
655 |
llama3_reorganize_model_repo,
|
656 |
additional_tags_prepend,
|
657 |
additional_tags_append,
|
658 |
+
tag_results,
|
659 |
],
|
660 |
+
outputs=[download_file, sorted_general_strings, rating, character_res, general_res, tag_results,],
|
661 |
)
|
662 |
|
663 |
gr.Examples(
|
|
|
672 |
],
|
673 |
)
|
674 |
|
675 |
+
demo.queue(max_size=2)
|
676 |
demo.launch(inbrowser=True)
|
677 |
|
678 |
|