Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from env import config_env | |
| config_env() | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import cv2 | |
| import dotenv | |
| dotenv.load_dotenv() | |
| import numpy as np | |
| import gradio as gr | |
| import glob | |
| from inference_sam import segmentation_sam | |
| from explanations import explain | |
| from inference_resnet import get_triplet_model | |
| from inference_resnet_v2 import get_resnet_model,inference_resnet_embedding_v2,inference_resnet_finer_v2 | |
| from inference_beit import get_triplet_model_beit | |
| import pathlib | |
| import tensorflow as tf | |
| from closest_sample import get_images,get_diagram | |
# Download the example images and the fossil dataset from the Hugging Face Hub
# the first time the app starts (skipped when the local folders already exist).
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'), repo_type='dataset', local_dir='images')
if not os.path.exists('dataset'):
    REPO_ID = 'Serrelab/Fossils'
    token = os.environ.get('READ_TOKEN')
    # Never echo the secret itself: Space build/run logs may be readable by others.
    print(f"Read token present: {token is not None}")
    if token is None:
        print("warning! A read token in env variables is needed for authentication.")
    snapshot_download(repo_id=REPO_ID, token=token, repo_type='dataset', local_dir='dataset')
# Markdown/HTML snippets rendered by the Gradio UI.
# TODO(review): the ArXiv link has an empty href; fill in the paper URL.
HEADER = '''
<h2><b>Official Gradio Demo</b></h2><h2><a href='https://huggingface.co/spaces/Serrelab/fossil_app' target='_blank'><b>Identifying Florissant Leaf Fossils to Family using Deep Neural Networks </b></a></h2>
Code: <a href='https://github.com/orgs/serre-lab/projects/2' target='_blank'>GitHub</a>. Paper: <a href='' target='_blank'>ArXiv</a>.
'''
# TODO: draft intro text, not displayed anywhere yet:
#   **Fossil** a brief intro to the project.
#   **Important Notes:** a short list of notes for users is still to be written.
USER_GUIDE = """
<div style='background-color: #f0f0f0; padding: 20px; border-radius: 10px;'>
<h2>User Guide</h2>
<p>Welcome to the interactive fossil exploration tool. Here's how to get started:</p>
<ul>
<li><strong>Upload an Image:</strong> Drag and drop or choose from given samples to upload images of fossils.</li>
<li><strong>Process Image:</strong> After uploading, click the 'Process Image' button to analyze the image.</li>
<li><strong>Explore Results:</strong> Switch to the 'Specimen Workbench' tab to check out detailed analysis and results.</li>
</ul>
<h3>Tips</h3>
<ul>
<li>Zoom into images on the workbench for finer details.</li>
<li>Use the examples below as references for what types of images to upload.</li>
</ul>
<p>Enjoy exploring!</p>
</div>
"""
TIPS = """
## Tips
- Zoom into images on the workbench for finer details.
- Use the examples below as references for what types of images to upload.
Enjoy exploring!
"""
# TODO(review): '[email protected]' looks like an e-mail obfuscation placeholder
# left by a web scrape, not a real address; restore the actual contact e-mail.
CITATION = '''
📧 **Contact** <br>
If you have any questions, feel free to contact us at <b>[email protected]</b>.
'''
# TODO: draft sections, not displayed anywhere yet:
#   **Citation** - cite using this bibtex: ... (bibtex still to be added)
#   **License** - still to be written.
def get_model(model_name):
    """Build the classifier registered as ``model_name`` and load its weights.

    Returns:
        A ``(model, n_classes)`` tuple.

    Raises:
        ValueError: if ``model_name`` is not one of the known models.
    """
    # name -> (number of classes, weights file) for the ResNet50V2 triplet models.
    resnet_specs = {
        'Mummified 170': (170, 'model_classification/mummified-170.h5'),
        'Rock 170': (171, 'model_classification/rock-170.h5'),
    }
    # name -> (number of classes, weights file) for the BEiT triplet models.
    beit_specs = {
        'Fossils 142': (142, 'model_classification/fossil-142.h5'),
        'Fossils new': (142, 'model_classification/fossil-new.h5'),
    }

    if model_name in resnet_specs:
        n_classes, weights_file = resnet_specs[model_name]
        model = get_triplet_model(
            input_shape=(600, 600, 3),
            embedding_units=256,
            embedding_depth=2,
            backbone_class=tf.keras.applications.ResNet50V2,
            nb_classes=n_classes,
            load_weights=False,
            finer_model=True,
            backbone_name='Resnet50v2',
        )
        model.load_weights(weights_file)
    elif model_name in beit_specs:
        n_classes, weights_file = beit_specs[model_name]
        model = get_triplet_model_beit(
            input_shape=(384, 384, 3),
            embedding_units=256,
            embedding_depth=2,
            n_classes=n_classes,
        )
        model.load_weights(weights_file)
    elif model_name == 'Fossils':
        n_classes = 142
        # get_resnet_model returns three values; only the model is needed here.
        model, _, _ = get_resnet_model('model_classification/fossil-model.h5')
    else:
        raise ValueError(f"Model name '{model_name}' is not recognized")
    return model, n_classes
def segment_image(input_image):
    """Segment ``input_image`` with ``segmentation_sam`` and return its result."""
    return segmentation_sam(input_image)
def classify_image(input_image, model_name):
    """Classify ``input_image`` with the model registered as ``model_name``.

    Returns:
        Whatever the matching inference helper returns (shown in a gr.Label),
        or None when ``model_name`` is not a known model.
    """
    # Pick the inference helper and input resolution for the model family.
    if model_name in ('Rock 170', 'Mummified 170'):
        from inference_resnet import inference_resnet_finer as infer
        size = 600
    elif model_name in ('Fossils 142', 'Fossils new'):
        from inference_beit import inference_resnet_finer_beit as infer
        size = 384
    elif model_name == 'Fossils':
        # Imported at module level; no need to load inference_beit for this path.
        infer = inference_resnet_finer_v2
        size = 384
    else:
        return None
    model, n_classes = get_model(model_name)
    return infer(input_image, model, size=size, n_classes=n_classes)
def get_embeddings(input_image, model_name):
    """Embed ``input_image`` with the model registered as ``model_name``.

    Returns:
        Whatever the matching embedding helper returns, or None when
        ``model_name`` is not a known model.
    """
    # Pick the embedding helper and input resolution for the model family.
    if model_name in ('Rock 170', 'Mummified 170'):
        from inference_resnet import inference_resnet_embedding as embed
        size = 600
    elif model_name in ('Fossils 142', 'Fossils new'):
        from inference_beit import inference_resnet_embedding_beit as embed
        size = 384
    elif model_name == 'Fossils':
        # Imported at module level; no need to load inference_beit for this path.
        embed = inference_resnet_embedding_v2
        size = 384
    else:
        return None
    model, n_classes = get_model(model_name)
    return embed(input_image, model, size=size, n_classes=n_classes)
def find_closest(input_image, model_name):
    """Look up the dataset samples nearest to ``input_image``.

    The image is embedded with ``model_name`` and the neighbours are found by
    ``get_images``; its two results are returned as ``(classes, paths)``.
    """
    query_embedding = get_embeddings(input_image, model_name)
    nearest_classes, nearest_paths = get_images(query_embedding)
    return nearest_classes, nearest_paths
def generate_diagram_closest(input_image, model_name, top_k):
    """Return the class-distribution chart from ``get_diagram`` for the ``top_k`` nearest samples."""
    return get_diagram(get_embeddings(input_image, model_name), top_k)
def explain_image(input_image, model_name, explain_method, nb_samples):
    """Compute explanation heatmaps for ``input_image`` under ``model_name``.

    Returns:
        The ``(classes, explanations)`` pair produced by ``explain``.
    """
    model, n_classes = get_model(model_name)
    # The ResNet50V2 triplet models are built for 600x600 inputs; the BEiT models
    # and 'Fossils' run at 384 (as in classify_image). Note that
    # `model_name == 'Fossils 142' or 'Fossils new'` would always be truthy.
    size = 600 if model_name in ('Mummified 170', 'Rock 170') else 384
    classes, exp_list = explain(model, input_image, explain_method, nb_samples, size=size, n_classes=n_classes)
    print('done')
    return classes, exp_list
def setup_examples():
    """Create the fossil and leaf ``gr.Examples`` galleries feeding ``input_image``.

    Must be called inside the Blocks context after the module-level
    ``input_image`` component has been created. Samples are the ``*.jpg``
    files under ``images/``, selected by a substring of their path.
    """
    all_jpgs = sorted(pathlib.Path('images/').rglob('*.jpg'))

    fossil_samples = []
    leaf_samples = []
    for jpg in all_jpgs:
        path_text = str(jpg)
        if 'selected fossil examples' in path_text:
            fossil_samples.append(jpg.as_posix())
        if 'leaves' in path_text:
            leaf_samples.append([jpg.as_posix()])

    examples_fossils = gr.Examples(fossil_samples[:12], inputs=input_image, examples_per_page=5, label='Fossils Examples from the dataset')
    examples_leaves = gr.Examples(leaf_samples[:19], inputs=input_image, examples_per_page=12, label='Leaves Examples from the dataset')
    return examples_fossils, examples_leaves
def preprocess_image(image, output_size=(300, 300)):
    """Pad ``image`` to a square with black borders, then resize it.

    Args:
        image: array whose first two dimensions are (height, width).
        output_size: target size passed to ``cv2.resize`` as (width, height).

    Returns:
        The padded, resized image.
    """
    h, w = image.shape[:2]
    diff = abs(h - w)
    # Split the padding unevenly when diff is odd so the result is exactly
    # square; padding diff // 2 on both sides would leave it one pixel short.
    before, after = diff // 2, diff - diff // 2
    if h > w:
        image_padded = cv2.copyMakeBorder(image, 0, 0, before, after, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    else:
        image_padded = cv2.copyMakeBorder(image, before, after, 0, 0, cv2.BORDER_CONSTANT, value=[0, 0, 0])
    return cv2.resize(image_padded, output_size, interpolation=cv2.INTER_AREA)
def update_display(image):
    """Prepare an uploaded image for the workbench and reset its controls.

    Returns the 12 values wired to the "Process Image" outputs, in order:
    the untouched upload, the padded preview from ``preprocess_image`` twice,
    a status message, the default model / explanation method / Rise sampling
    size / top-k, and None four times to clear previous results.

    Raises:
        gr.Error: if no image has been uploaded or chosen yet.
    """
    if image is None:
        # Without this guard preprocess_image fails with an opaque AttributeError.
        raise gr.Error("Please upload or choose an image before clicking 'Process Image'.")
    processed_image = preprocess_image(image)
    instruction = "Image ready. Please switch to the 'Specimen Workbench' tab to check out further analysis and outputs."
    # Defaults mirror the initial values of the workbench widgets.
    model_name = "Fossils"
    explain_method = "Rise"
    sampling_size = 5
    top_k = 50
    # Clear predictions, explanations, neighbours and chart from a previous image.
    class_predicted = exp_gallery = closest_gallery = diagram = None
    return (image, processed_image, processed_image, instruction, model_name, explain_method,
            sampling_size, top_k, class_predicted, exp_gallery, closest_gallery, diagram)
def update_slider_visibility(explain_method):
    """Show the Rise sampling-size slider only when "Rise" is selected.

    Only visibility is updated, so the slider keeps the range and value it was
    created with (or the user has since chosen). It is not reset to another
    range and default.
    """
    is_rise = explain_method == "Rise"
    return gr.update(visible=is_rise)
# Minimalist theme.
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    # Landing tab: upload an image or pick one of the examples.
    with gr.Tab(" Florissant Fossils"):
        gr.Markdown(HEADER)
        with gr.Row():
            with gr.Column():
                gr.Markdown(USER_GUIDE)
            with gr.Column(scale=2):
                with gr.Column(scale=2):
                    instruction_text = gr.Textbox(label="Instructions", value="Upload/Choose an image and click 'Process Image'.")
                    input_image = gr.Image(label="Input", width="100%", container=True)
                    process_button = gr.Button("Process Image")
                with gr.Column(scale=1):
                    # setup_examples() reads the module-level `input_image` created above.
                    examples_fossils, examples_leaves = setup_examples()
        gr.Markdown(CITATION)

    # Workbench tab: classification, explanations and nearest neighbours.
    with gr.Tab("Specimen Workbench"):
        with gr.Row():
            with gr.Column():
                # Hidden copy of the unpadded upload; every analysis runs on it.
                original_image = gr.Image(visible=False)
                workbench_image = gr.Image(label="Workbench Image")
                classify_image_button = gr.Button("Classify Image")
            with gr.Column():
                model_name = gr.Dropdown(
                    ["Mummified 170", "Rock 170", "Fossils 142", "Fossils new", "Fossils"],
                    multiselect=False,
                    value="Fossils",  # default option
                    label="Model",
                    interactive=True,
                    info="Choose the model you'd like to use"
                )
                explain_method = gr.Dropdown(
                    ["Sobol", "HSIC", "Rise", "Saliency"],
                    multiselect=False,
                    value="Rise",  # default option
                    label="Explain method",
                    interactive=True,
                    info="Choose one method to explain the model"
                )
                sampling_size = gr.Slider(1, 30, value=5, label="Sampling Size in Rise", interactive=True, visible=True,
                                          info="Choose between 1 and 30")
                top_k = gr.Slider(10, 200, value=50, label="Number of Closest Samples for Distribution Chart", interactive=True, info="Choose between 10 and 200")
                explain_method.change(
                    fn=update_slider_visibility,
                    inputs=explain_method,
                    outputs=sampling_size
                )
        with gr.Row():
            with gr.Column(scale=1):
                class_predicted = gr.Label(label='Class Predicted', num_top_classes=10)
            with gr.Column(scale=4):
                with gr.Accordion("Explanations "):
                    gr.Markdown("Computing Explanations from the model")
                    with gr.Column():
                        with gr.Row():
                            exp_gallery = gr.Gallery(label="Explanation Heatmaps for top 5 predicted classes", show_label=False, elem_id="gallery", columns=[5], rows=[1], height='auto', allow_preview=True, preview=None)
                        generate_explanations = gr.Button("Generate Explanations")
                with gr.Accordion('Closest Fossil Images'):
                    gr.Markdown("Finding the closest images in the dataset")
                    with gr.Row():
                        closest_gallery = gr.Gallery(label="Closest Images", show_label=False, elem_id="gallery", columns=[5], rows=[1], height='auto', allow_preview=True, preview=None)
                    find_closest_btn = gr.Button("Find Closest Images")
                    classify_image_button.click(classify_image, inputs=[original_image, model_name], outputs=class_predicted)
                # Placeholder section: no leaf-neighbour search is wired up yet.
                with gr.Accordion('Closest Leaves Images'):
                    gr.Markdown("5 closest leaves")
                with gr.Accordion("Class Distribution of Closest Samples "):
                    gr.Markdown("Visualize class distribution of top-k closest samples in our dataset")
                    with gr.Column():
                        with gr.Row():
                            diagram = gr.Image(label='Bar Chart')
                        generate_diagram = gr.Button("Generate Diagram")

    def update_exp_outputs(input_image, model_name, explain_method, nb_samples):
        """Run the explainer and caption up to five heatmaps with their predicted class."""
        labels, images = explain_image(input_image, model_name, explain_method, nb_samples)
        # zip() stops at the shorter sequence and the slice caps it at five,
        # so fewer than five explanations cannot raise IndexError.
        return [
            (image, f"Predicted Class {rank}: {label}")
            for rank, (image, label) in enumerate(list(zip(images, labels))[:5])
        ]

    generate_explanations.click(fn=update_exp_outputs, inputs=[original_image, model_name, explain_method, sampling_size], outputs=[exp_gallery])

    def update_closest_outputs(input_image, model_name):
        """Caption up to five nearest dataset images with their class labels."""
        labels, images = find_closest(input_image, model_name)
        # Same bounded pairing as above: tolerate fewer than five neighbours.
        return list(zip(images, labels))[:5]

    find_closest_btn.click(fn=update_closest_outputs, inputs=[original_image, model_name], outputs=[closest_gallery])
    generate_diagram.click(generate_diagram_closest, inputs=[original_image, model_name, top_k], outputs=diagram)
    process_button.click(
        fn=update_display,
        inputs=input_image,
        outputs=[original_image, input_image, workbench_image, instruction_text, model_name, explain_method, sampling_size, top_k, class_predicted, exp_gallery, closest_gallery, diagram]
    )
# Queue incoming requests so concurrent users are handled in turn.
demo.queue()
if os.getenv('SYSTEM') != 'spaces':
    demo.launch()
else:
    # On Hugging Face Spaces the demo is protected by credentials from the environment.
    credentials = (os.environ.get('USERNAME'), os.environ.get('PASSWORD'))
    demo.launch(width='40%', auth=credentials)