Spaces: Running on T4
| import os | |
| import sys | |
| from env import config_env | |
| config_env() | |
| import gradio as gr | |
| from huggingface_hub import snapshot_download | |
| import cv2 | |
| import dotenv | |
| dotenv.load_dotenv() | |
| import numpy as np | |
| import gradio as gr | |
| import glob | |
| from inference_sam import segmentation_sam | |
| from explanations import explain | |
| from inference_resnet import get_triplet_model | |
| import pathlib | |
| import tensorflow as tf | |
| from closest_sample import get_images | |
# Fetch the example images (a Hugging Face *dataset* repo) into ./images,
# but only when that directory is not already present.
if not os.path.exists('images'):
    REPO_ID = 'Serrelab/image_examples_gradio'
    snapshot_download(
        repo_id=REPO_ID,
        repo_type='dataset',
        local_dir='images',
        token=os.environ.get('READ_TOKEN'),
    )
import functools

# Per-model settings for the ResNet50V2 triplet classifiers:
# model name -> (number of output classes, trained weight file).
_TRIPLET_MODEL_SPECS = {
    'Mummified 170': (170, 'model_classification/mummified-170.h5'),
    # NOTE(review): 171 classes despite the "170" in the name; kept as-is
    # because it must match the checkpoint -- confirm against training.
    'Rock 170': (171, 'model_classification/rock-170.h5'),
}


@functools.lru_cache(maxsize=None)
def _load_triplet_model(model_name):
    """Build and load the triplet model for a validated ``model_name`` (cached).

    Building a ResNet50V2 and loading its weights is expensive, so each model
    is loaded once and reused. Only names in ``_TRIPLET_MODEL_SPECS`` reach
    this function, so the cache holds at most one entry per configured model.
    """
    n_classes, weights_path = _TRIPLET_MODEL_SPECS[model_name]
    model = get_triplet_model(
        input_shape=(600, 600, 3),
        embedding_units=256,
        embedding_depth=2,
        backbone_class=tf.keras.applications.ResNet50V2,
        nb_classes=n_classes,
        load_weights=False,
        finer_model=True,
        backbone_name='Resnet50v2',
    )
    model.load_weights(weights_path)
    return model, n_classes


def get_model(model_name):
    """Return the trained triplet classifier selected by ``model_name``.

    Args:
        model_name: 'Mummified 170' or 'Rock 170'.

    Returns:
        A ``(model, n_classes)`` tuple.

    Raises:
        ValueError: if ``model_name`` is not a configured model. (Previously
            the string 'Error' was returned; every caller unpacks the result
            into two values, so that also ended in a ValueError, but an
            opaque "too many values to unpack" one.)
    """
    if model_name not in _TRIPLET_MODEL_SPECS:
        raise ValueError(f'Unknown model name: {model_name!r}')
    return _load_triplet_model(model_name)
def segment_image(input_image):
    """Run ``segmentation_sam`` on ``input_image`` and hand back its output.

    The result is displayed in the "Segmented Image" widget of the UI.
    """
    return segmentation_sam(input_image)
def classify_image(input_image, model_name):
    """Classify ``input_image`` with the model selected in the UI.

    Args:
        input_image: image value coming from the Gradio input component.
        model_name: 'Rock 170', 'Mummified 170' or 'Fossils 19'.

    Returns:
        Whatever the underlying inference helper returns (it is rendered by a
        ``gr.Label``), or None when ``model_name`` is not recognised.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        # Imported lazily, as in the original, so the helper is only loaded
        # when a ResNet model is actually used.
        from inference_resnet import inference_resnet_finer
        model, n_classes = get_model(model_name)
        return inference_resnet_finer(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils 19':
        # inference_dino only takes the model name. get_model() has no
        # 'Fossils 19' entry, so the previous get_model() call here always
        # failed before inference could run (and its result was unused).
        from inference_beit import inference_dino
        return inference_dino(input_image, model_name)
    return None
def get_embeddings(input_image, model_name):
    """Compute the embedding of ``input_image`` under the selected model.

    Args:
        input_image: image value coming from the Gradio input component.
        model_name: 'Rock 170', 'Mummified 170' or 'Fossils 19'.

    Returns:
        The embedding produced by the inference helper (passed on to
        ``get_images`` by ``find_closest``), or None for an unknown model.
    """
    if model_name in ('Rock 170', 'Mummified 170'):
        from inference_resnet import inference_resnet_embedding
        model, n_classes = get_model(model_name)
        return inference_resnet_embedding(input_image, model, size=600, n_classes=n_classes)
    if model_name == 'Fossils 19':
        # get_model() has no 'Fossils 19' entry, so the previous get_model()
        # call here always failed; inference_dino only needs the model name.
        # NOTE(review): classify_image uses this same helper, so its output
        # may be a prediction rather than an embedding -- confirm that
        # get_images can consume it.
        from inference_beit import inference_dino
        return inference_dino(input_image, model_name)
    return None
def find_closest(input_image, model_name):
    """Look up dataset images close to ``input_image``.

    The image is embedded with ``get_embeddings`` for the chosen model and the
    embedding is handed to ``get_images``. The result feeds the five
    "closest image" widgets in the UI.
    """
    return get_images(get_embeddings(input_image, model_name))
def explain_image(input_image, model_name):
    """Produce three attribution maps for ``input_image`` under ``model_name``.

    Returns:
        A ``(saliency, integrated, smoothgrad)`` tuple, in the order returned
        by ``explain`` and consumed by the three explanation image widgets.
    """
    classifier, num_classes = get_model(model_name)
    maps = explain(classifier, input_image, n_classes=num_classes)
    saliency_map, integrated_map, smoothgrad_map = maps
    print('done')
    return (saliency_map, integrated_map, smoothgrad_map)
# Gradio UI: input/segmentation/model-selection row, dataset examples,
# model explanations, and closest-image retrieval, plus the event wiring.
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    with gr.Tab("Florissant Fossils"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Input")
                classify_image_button = gr.Button("Classify Image")
            with gr.Column():
                segmented_image = gr.Image(label="Segmented Image", type='numpy')
                segment_button = gr.Button("Segment Image")
            with gr.Column():
                model_name = gr.Dropdown(
                    ["Mummified 170", "Rock 170"],
                    multiselect=False,
                    value="Rock 170",
                    label="Model",
                    interactive=True,
                )
                class_predicted = gr.Label(label='Class Predicted', num_top_classes=10)
        with gr.Row():
            # Example pictures are read from the local images/ directory
            # (the dataset snapshot fetched at start-up); up to 19 per group.
            paths = sorted(pathlib.Path('images/').rglob('*.jpg'))
            samples = [[path.as_posix()] for path in paths if 'fossils' in str(path)][:19]
            examples_fossils = gr.Examples(
                samples,
                inputs=input_image,
                examples_per_page=10,
                label='Fossils Examples from the dataset',
            )
            samples = [[path.as_posix()] for path in paths if 'leaves' in str(path)][:19]
            examples_leaves = gr.Examples(
                samples,
                inputs=input_image,
                examples_per_page=5,
                label='Leaves Examples from the dataset',
            )
        with gr.Accordion("Explanations"):
            gr.Markdown("Computing Explanations from the model")
            with gr.Row():
                # The variable names are historical. The three widgets show,
                # in order, the saliency, integrated-gradients and SmoothGrad
                # maps returned by explain_image (the last one was previously
                # mislabelled 'gradcam').
                saliency = gr.Image(label="saliency")
                gradcam = gr.Image(label='integrated gradients')
                guided_gradcam = gr.Image(label='smoothgrad')
            generate_explanations = gr.Button("Generate Explanations")
        with gr.Accordion('Closest Images'):
            gr.Markdown("Finding the closest images in the dataset")
            with gr.Row():
                closest_image_0 = gr.Image(label='Closest Image')
                closest_image_1 = gr.Image(label='Second Closest Image')
                closest_image_2 = gr.Image(label='Third Closest Image')
                closest_image_3 = gr.Image(label='Fourth Closest Image')
                closest_image_4 = gr.Image(label='Fifth Closest Image')
            find_closest_btn = gr.Button("Find Closest Images")

    # Event listeners must be registered inside the Blocks context.
    segment_button.click(segment_image, inputs=input_image, outputs=segmented_image)
    classify_image_button.click(
        classify_image,
        inputs=[input_image, model_name],
        outputs=class_predicted,
    )
    generate_explanations.click(
        explain_image,
        inputs=[input_image, model_name],
        outputs=[saliency, gradcam, guided_gradcam],
    )
    find_closest_btn.click(
        find_closest,
        inputs=[input_image, model_name],
        outputs=[closest_image_0, closest_image_1, closest_image_2, closest_image_3, closest_image_4],
    )
# Enable request queuing, then start the server. On Hugging Face Spaces
# (SYSTEM == 'spaces') the app is protected by USERNAME/PASSWORD from the env.
demo.queue()
running_on_spaces = os.getenv('SYSTEM') == 'spaces'
if not running_on_spaces:
    demo.launch()
else:
    credentials = (os.environ.get('USERNAME'), os.environ.get('PASSWORD'))
    demo.launch(width='40%', auth=credentials)