Spaces:
Sleeping
Sleeping
import streamlit as st | |
import keras_hub | |
from PIL import Image | |
import numpy as np | |
# Maps the display label offered in the model picker to the KerasHub preset
# name that is handed to the ``from_preset`` loaders below.
classification_models = dict(
    ResNet18="resnet_18_imagenet",
    ResNet50="resnet_50_imagenet",
)
@st.cache_resource
def load_preprocessor(model_name):
    """Load the KerasHub image-classifier preprocessor for a preset.

    Streamlit re-executes the script on every widget interaction, so the
    result is memoized with ``st.cache_resource``. Repeated calls with the
    same ``model_name`` return the same object instead of rebuilding it from
    the preset each time. The cached object is shared by all sessions of the
    app.

    Args:
        model_name: KerasHub preset name, e.g. ``"resnet_18_imagenet"``.

    Returns:
        The ``keras_hub.models.ImageClassifierPreprocessor`` built by
        ``from_preset``.
    """
    return keras_hub.models.ImageClassifierPreprocessor.from_preset(model_name)
@st.cache_resource
def load_model(model_name):
    """Load a pre-trained image classifier from KerasHub.

    Streamlit re-executes the script on every widget interaction, so the
    model is memoized with ``st.cache_resource``. It is built from the preset
    once per ``model_name`` and then reused, instead of being reconstructed
    on every rerun. The cached object is shared by all sessions of the app.

    Args:
        model_name: KerasHub preset name, e.g. ``"resnet_50_imagenet"``.

    Returns:
        The ``keras_hub.models.ImageClassifier`` built by ``from_preset``.
    """
    return keras_hub.models.ImageClassifier.from_preset(model_name)
def upload_image():
    """Render an image uploader and return the uploaded image as a batch.

    Returns:
        A float32 NumPy array of shape ``(1, height, width, 3)`` holding the
        uploaded image's RGB pixel values (0-255), or ``None`` if no file has
        been uploaded yet.
    """
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return None
    # PNG uploads are often RGBA or palette ("P") mode, and some JPEGs are
    # grayscale. Converting to RGB guarantees a 3-channel (H, W, 3) array
    # instead of (H, W, 4) or (H, W). The downstream classifier presumably
    # expects 3 channels (ImageNet presets).
    # The context manager closes the source file once the pixels are copied.
    with Image.open(uploaded_file) as image:
        rgb_image = image.convert("RGB")
    pixels = np.asarray(rgb_image, dtype="float32")
    # Add a leading batch axis: (H, W, 3) -> (1, H, W, 3).
    return np.expand_dims(pixels, axis=0)
def vision_page():
    """Render the "Vision Models" page.

    The page has three tabs:
    - an image-classification demo backed by the KerasHub presets listed in
      ``classification_models``;
    - placeholder tabs for object detection and segmentation.
    """
    st.header("Vision Models")
    st.write("Explore Vision Models including Image Classification, Object Detection, and Segmentation.")
    # Tabs for different vision tasks
    tab1, tab2, tab3 = st.tabs(["Image Classification", "Object Detection", "Segmentation"])
    with tab1:
        st.subheader("Image Classification")
        model_name = st.selectbox("Choose a pre-trained model:", list(classification_models))
        preset = classification_models[model_name]
        image = upload_image()
        # Streamlit re-runs this function on every widget interaction, so only
        # load the (potentially large) preprocessor and model once there is
        # actually an image to classify. Loading eagerly also delayed the
        # uploader from appearing.
        if image is not None:
            with st.spinner(f"Running {model_name}..."):
                preprocessor = load_preprocessor(preset)
                model = load_model(preset)
                preprocessed_image = preprocessor(image)
                raw_predictions = model(preprocessed_image)
                predictions = keras_hub.utils.decode_imagenet_predictions(raw_predictions)
            col1, col2 = st.columns([1, 1])
            with col1:
                st.image(image[0].astype("uint8"), caption="Uploaded Image", use_container_width=True)
            with col2:
                st.write("##### Top Predictions:")
                # predictions[0] holds the decoded (class_name, score) pairs for
                # the single image in the batch; only the names are shown.
                for rank, (class_name, _score) in enumerate(predictions[0], start=1):
                    st.write(f"{rank}: {class_name}")
    with tab2:
        st.subheader("Object Detection")
        st.write("Object Detection functionality is under development.")
    with tab3:
        st.subheader("Segmentation")
        st.write("Segmentation functionality is under development.")