|
import streamlit as st |
|
import torch |
|
from torchvision import models, transforms |
|
from PIL import Image |
|
|
|
|
|
@st.cache_resource
def _load_model():
    """Build the pretrained ResNet-50 once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction, so
    constructing the model at module level without caching would re-create it
    and reload its weights on every rerun.

    Returns:
        A torchvision ResNet-50 with ImageNet weights, switched to eval mode.
    """
    # IMAGENET1K_V1 is the checkpoint the deprecated ``pretrained=True`` flag
    # selected, so predictions are unchanged.
    net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    # Put BatchNorm layers into inference behaviour (use running statistics).
    net.eval()
    return net


model = _load_model()
|
|
|
|
|
# Preprocessing applied to every uploaded image before it reaches the model:
#   1. resize straight to 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the standard ImageNet mean/std used by
#      torchvision's pretrained classification models.
# NOTE(review): torchvision's reference ImageNet evaluation preprocessing is
# Resize(256) + CenterCrop(224); the direct squash here may cost some accuracy.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
def classify_image(image):
    """Return the class index predicted by the module-level ``model`` for *image*.

    Args:
        image: A PIL image in any mode. Non-RGB images are converted to RGB
            first, because ``transform`` normalizes exactly three channels and
            the model's first convolution expects three input channels.

    Returns:
        The index of the highest-scoring class, as a Python ``int``.
    """
    # PNG uploads are often RGBA, palette ("P") or grayscale ("L"); without
    # this, ToTensor would produce 4 or 1 channels and the pipeline would fail.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Preprocess, then add a leading batch dimension of size 1.
    batch = transform(image).unsqueeze(0)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(batch)

    # argmax over the class dimension; like torch.max, ties resolve to the
    # first maximal index.
    return logits.argmax(dim=1).item()
|
|
|
def main():
    """Render the Streamlit page: upload an image, show it, and classify it."""
    st.title("Image Classification with PyTorch and Streamlit")

    uploaded_file = st.file_uploader("Choose a file", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # Nothing uploaded yet (or the upload was cleared); wait for input.
        return

    try:
        image = Image.open(uploaded_file)
        # Image.open is lazy; force a full decode so corrupt or truncated
        # uploads are reported here instead of failing inside the model call.
        image.load()
    except OSError:
        # PIL's UnidentifiedImageError is a subclass of OSError.
        st.error("Could not read the uploaded file as an image.")
        return

    st.image(image, caption="Uploaded Image.", use_column_width=True)

    class_idx = classify_image(image)

    # Translate the raw index into a readable ImageNet class name. These are
    # the category names shipped with the IMAGENET1K_V1 checkpoint, which is
    # the one ``resnet50(pretrained=True)`` loads.
    categories = models.ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
    class_label = f"{categories[class_idx]} (index {class_idx})"
    st.write("Class Prediction: ", class_label)


# Launch the UI when this file is executed as a script (e.g. `streamlit run`),
# but not when it is merely imported.
if __name__ == "__main__":
    main()
|
|