# ModelLens / vision_models.py
# Author: sravanneeli — "base version" (commit f6be055), 2.23 kB
import streamlit as st
import keras_hub
from PIL import Image
import numpy as np
# Maps each display name shown in the model selectbox to the KerasHub preset
# identifier passed to ``from_preset`` when loading that classifier.
classification_models = dict(
    ResNet18="resnet_18_imagenet",
    ResNet50="resnet_50_imagenet",
)
@st.cache_resource
def load_preprocessor(model_name):
    """Load the KerasHub image-classifier preprocessor for a preset.

    Streamlit re-executes the whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``. The preset is resolved
    once per distinct ``model_name``, and the same object is reused on later
    reruns instead of being rebuilt each time.

    Args:
        model_name: KerasHub preset name, e.g. ``"resnet_50_imagenet"``.

    Returns:
        The ``keras_hub.models.ImageClassifierPreprocessor`` for that preset.
    """
    return keras_hub.models.ImageClassifierPreprocessor.from_preset(model_name)
@st.cache_resource
def load_model(model_name):
    """Load a pre-trained image classifier from KerasHub.

    Loading a preset builds the network and reads its weights, which is far
    too slow to repeat on every Streamlit rerun. ``st.cache_resource`` keeps
    one instance per distinct ``model_name`` and returns it on later calls.

    Args:
        model_name: KerasHub preset name, e.g. ``"resnet_50_imagenet"``.

    Returns:
        The ``keras_hub.models.ImageClassifier`` for that preset.
    """
    return keras_hub.models.ImageClassifier.from_preset(model_name)
def upload_image():
    """Show an image uploader and return the chosen file as a one-image batch.

    Returns:
        ``None`` until the user uploads a file. After that, a ``float32`` NumPy
        array of shape ``(1, height, width, 3)`` holding RGB pixel values in
        ``[0, 255]``.
    """
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return None
    # Normalise to 3-channel RGB. PNGs are often RGBA or palette ("P") images
    # and JPEGs can be grayscale ("L"). Without this they would produce
    # 4-channel or 2-D arrays instead of the 3-channel input the ImageNet
    # ResNet presets are built for.
    with Image.open(uploaded_file) as img:
        rgb = img.convert("RGB")
    # Add a leading batch axis: the preprocessor/model are called on batches.
    return np.expand_dims(np.asarray(rgb, dtype="float32"), axis=0)
def vision_page():
    """Render the "Vision Models" page.

    The page has three tabs: a working image-classification demo, and
    placeholders for object detection and segmentation, which are not
    implemented yet.
    """
    st.header("Vision Models")
    st.write("Explore Vision Models including Image Classification, Object Detection, and Segmentation.")
    # Tabs for different vision tasks
    tab1, tab2, tab3 = st.tabs(["Image Classification", "Object Detection", "Segmentation"])
    with tab1:
        _render_classification_tab()
    with tab2:
        _render_placeholder_tab("Object Detection")
    with tab3:
        _render_placeholder_tab("Segmentation")


def _render_classification_tab():
    """Let the user pick a classifier and upload an image, then list its top predictions."""
    st.subheader("Image Classification")
    model_name = st.selectbox("Choose a pre-trained model:", list(classification_models.keys()))
    preset = classification_models[model_name]
    image = upload_image()
    if image is None:
        # Nothing to classify yet. Skip loading the preprocessor and model,
        # which would otherwise happen on every Streamlit rerun.
        return
    preprocessor = load_preprocessor(preset)
    model = load_model(preset)
    # The model is called directly, so preprocessing is applied explicitly.
    preprocessed_image = preprocessor(image)
    raw_predictions = model(preprocessed_image)
    # Decode the ImageNet class scores into (class_name, score) pairs per image.
    predictions = keras_hub.utils.decode_imagenet_predictions(raw_predictions)
    col1, col2 = st.columns([1, 1])
    with col1:
        st.image(image[0].astype("uint8"), caption="Uploaded Image", use_container_width=True)
    with col2:
        st.write("##### Top Predictions:")
        # Only class names are shown; the score is intentionally unused.
        for rank, (class_name, _score) in enumerate(predictions[0], start=1):
            st.write(f"{rank}: {class_name}")


def _render_placeholder_tab(title):
    """Render a tab for a vision task that has not been implemented yet."""
    st.subheader(title)
    st.write(f"{title} functionality is under development.")