File size: 2,234 Bytes
f6be055
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import streamlit as st
import keras_hub
from PIL import Image
import numpy as np

# Maps the label shown in the model selectbox to the KerasHub preset
# identifier passed to ``from_preset``.
classification_models = {
    "ResNet18": "resnet_18_imagenet",
    "ResNet50": "resnet_50_imagenet",
}


@st.cache_resource
def load_preprocessor(model_name):
    """Load the image-classifier preprocessor for a KerasHub preset.

    Streamlit re-executes the script on every widget interaction, so the
    result is cached with ``st.cache_resource``: each preset's preprocessor
    is built once and then reused across reruns.

    Args:
        model_name: KerasHub preset identifier (e.g. ``"resnet_50_imagenet"``).

    Returns:
        The ``ImageClassifierPreprocessor`` created by ``from_preset``.
    """
    return keras_hub.models.ImageClassifierPreprocessor.from_preset(model_name)

@st.cache_resource
def load_model(model_name):
    """Load a pre-trained image classifier from KerasHub.

    Streamlit re-executes the script on every widget interaction, so the
    result is cached with ``st.cache_resource``: each preset's model is
    instantiated once and then reused across reruns instead of being
    rebuilt every time.

    Args:
        model_name: KerasHub preset identifier (e.g. ``"resnet_50_imagenet"``).

    Returns:
        The ``ImageClassifier`` created by ``from_preset``.
    """
    return keras_hub.models.ImageClassifier.from_preset(model_name)

def upload_image():
    """Render an image uploader and return the upload as a batched array.

    Returns:
        A ``float32`` NumPy array of shape ``(1, height, width, 3)`` holding
        the RGB pixel values of the uploaded image, or ``None`` when no file
        has been uploaded yet.
    """
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return None

    # Uploads are not guaranteed to be 3-channel: PNGs are often RGBA or
    # palette-based and JPEGs can be grayscale or CMYK. Normalise to RGB so
    # the array always has exactly three channels.
    rgb_image = Image.open(uploaded_file).convert("RGB")
    pixels = np.asarray(rgb_image, dtype="float32")
    # Add a leading batch axis of size 1.
    return np.expand_dims(pixels, axis=0)

def vision_page():
    """Render the "Vision Models" page.

    The page has three tabs. "Image Classification" lets the user pick a
    pre-trained model, upload an image, and see the uploaded image next to
    the decoded ImageNet predictions. "Object Detection" and "Segmentation"
    only show a placeholder message.
    """
    st.header("Vision Models")
    st.write("Explore Vision Models including Image Classification, Object Detection, and Segmentation.")

    # Tabs for different vision tasks
    classification_tab, detection_tab, segmentation_tab = st.tabs(
        ["Image Classification", "Object Detection", "Segmentation"]
    )

    with classification_tab:
        st.subheader("Image Classification")
        model_name = st.selectbox("Choose a pre-trained model:", list(classification_models.keys()))

        image = upload_image()
        if image is not None:
            # Load the preprocessor and model only once there is something to
            # classify, so rendering the page without an upload does not pay
            # the model-loading cost.
            preset = classification_models[model_name]
            preprocessor = load_preprocessor(preset)
            model = load_model(preset)

            preprocessed_image = preprocessor(image)
            raw_predictions = model(preprocessed_image)
            predictions = keras_hub.utils.decode_imagenet_predictions(raw_predictions)

            image_col, predictions_col = st.columns([1, 1])
            with image_col:
                # image carries a leading batch axis; show the single image.
                st.image(image[0].astype("uint8"), caption="Uploaded Image", use_container_width=True)
            with predictions_col:
                st.write("##### Top Predictions:")
                # predictions[0] holds (class_name, score) pairs for the one
                # image in the batch; only the rank and class name are shown.
                for rank, (class_name, _score) in enumerate(predictions[0], start=1):
                    st.write(f"{rank}: {class_name}")

    with detection_tab:
        st.subheader("Object Detection")
        st.write("Object Detection functionality is under development.")

    with segmentation_tab:
        st.subheader("Segmentation")
        st.write("Segmentation functionality is under development.")