Spaces:
Runtime error
Runtime error
| import utils | |
| from huggingface_hub.keras_mixin import from_pretrained_keras | |
| from PIL import Image | |
| import streamlit as st | |
| import tensorflow as tf | |
# NOTE: ``st.cache(...)`` must be applied with ``@``. Calling it as a bare
# statement builds a decorator and throws it away, which leaves ``load_model``
# uncached so the model is loaded again on every Streamlit rerun.
# ``allow_output_mutation=True`` tells Streamlit not to hash the returned
# object on each cache hit (recommended by the Streamlit docs for large,
# mutable objects such as models).
@st.cache(show_spinner=True, allow_output_mutation=True)
def load_model():
    """Load the DINO model.

    Returns:
        The object returned by
        ``from_pretrained_keras("probing-vits/vit-dino-base16")``.
    """
    dino = from_pretrained_keras("probing-vits/vit-dino-base16")
    return dino
# Obtain the DINO model used for inference further down the script.
dino = load_model()

# Inputs: the image can be given as a URL or as an uploaded JPEG file.
st.title("Input your image")
image_url = st.text_input(
    "URL of image",
    value="https://dl.fbaipublicfiles.com/dino/img.png",
    placeholder="https://your-favourite-image.png",
)
uploaded_file = st.file_uploader(
    label="or an image file",
    type=["jpg", "jpeg"],
)
# Outputs
st.title("Original Image from URL")

if uploaded_file:
    # An uploaded file takes precedence over the URL. The URL is no longer
    # fetched in this case, so a broken or empty URL cannot crash the app
    # when a valid file has been provided.
    # The de-normalization later in this script uses 3-channel constants, so
    # force RGB here (JPEGs may be grayscale or CMYK).
    # NOTE(review): assumes utils.preprocess_image accepts an RGB PIL image,
    # as it did for the unconverted upload before - confirm.
    image = Image.open(uploaded_file).convert("RGB")
    preprocessed_image = utils.preprocess_image(image, "dino")
else:
    # Preprocess the same image but with normalization.
    image, preprocessed_image = utils.load_image_from_url(
        image_url,
        model_type="dino",
    )

st.image(image, caption="Original Image")
with st.spinner("Generating the attention scores..."):
    # The model returns a pair; only the second item (attention scores) is used.
    _, attention_score_dict = dino.predict(preprocessed_image)

with st.spinner("Generating the heat maps... HOLD ON!"):
    # Undo the normalization so the image is suitable for display:
    # scale by the per-channel std, shift by the per-channel mean (both on a
    # 0-255 scale), bring the result back to [0, 1] and clip.
    in1k_mean = tf.constant([c * 255 for c in (0.485, 0.456, 0.406)])
    in1k_std = tf.constant([c * 255 for c in (0.229, 0.224, 0.225)])
    preprocessed_img_orig = tf.clip_by_value(
        ((preprocessed_image * in1k_std) + in1k_mean) / 255.,
        0.0,
        1.0,
    ).numpy()

    # Build the heat maps from the attention scores and plot them.
    attentions = utils.attention_heatmap(
        image=preprocessed_img_orig,
        attention_score_dict=attention_score_dict,
    )
    utils.plot(image=preprocessed_img_orig, attentions=attentions)

# Show the attention maps
# NOTE(review): "heat_map.png" is presumably written by utils.plot above - confirm.
st.title("Attention 🔥 Maps")
image = Image.open("heat_map.png")
st.image(image, caption="Attention Heat Maps")