"""Streamlit app that classifies an uploaded image with a pre-trained ResNet-50."""

import streamlit as st
import torch
from PIL import Image
from torchvision import models, transforms

# Streamlit re-executes this whole script on every user interaction, so a bare
# module-level model would be rebuilt (and its weights reloaded) on each rerun.
# st.cache_resource (Streamlit >= 1.18) keeps one shared instance across
# reruns; on older Streamlit versions fall back to a no-op decorator.
_cache_resource = getattr(st, "cache_resource", lambda func: func)


@_cache_resource
def _load_model() -> torch.nn.Module:
    """Build ResNet-50 with ImageNet weights and put it in eval mode."""
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is
        # the weight set that flag used to select, so predictions are unchanged.
        net = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        net = models.resnet50(pretrained=True)
    net.eval()  # inference behaviour for batch-norm / dropout layers
    return net


# Load a pre-trained ResNet model (cached across Streamlit reruns when possible).
model = _load_model()

# Preprocessing for the input image: resize to 224x224, convert to a tensor,
# then normalize each of the 3 channels with the listed mean/std.
# NOTE(review): torchvision's reference pipeline for these weights is
# Resize(256) + CenterCrop(224); the square resize here distorts aspect ratio.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


def classify_image(image: Image.Image) -> int:
    """Return the index of the highest-scoring output class for *image*.

    The image is converted to 3-channel RGB first. RGBA, grayscale and palette
    images would otherwise yield a tensor whose channel count does not match
    the 3-channel normalization above or the model's input layer.
    """
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension
    with torch.no_grad():  # inference only; no autograd graph needed
        logits = model(batch)
    return logits.argmax(dim=1).item()


def main() -> None:
    """Render the upload widget and display the predicted class index."""
    st.title("Image Classification with PyTorch and Streamlit")

    uploaded_file = st.file_uploader("Choose a file", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force decoding here so corrupt or mislabeled
        # uploads fail inside this handler (UnidentifiedImageError is an OSError).
        image.load()
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return

    # Display the uploaded image
    # NOTE(review): newer Streamlit deprecates use_column_width in favour of
    # use_container_width; kept for compatibility with older versions.
    st.image(image, caption="Uploaded Image.", use_column_width=True)

    # Make a prediction and display the result
    class_idx = classify_image(image)
    st.write("Class Prediction: ", str(class_idx))


if __name__ == "__main__":
    main()