smishr-18's picture
Upload app.py
f5b6d4a verified
raw
history blame
1.25 kB
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd
from dataloader import transform
st.title("Flower Image Classification")

# Run inference on the GPU when one is visible to torch, otherwise on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


@st.cache_resource
def _load_model(dev):
    """Build ``Resnet50Flower102`` for *dev* and load the weights in model.pth.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps a single loaded model per *dev*, so the
    checkpoint is not re-read from disk and the network is not rebuilt on
    each rerun.

    Args:
        dev: device string ('cuda' or 'cpu') passed to the model constructor
            and used as ``map_location`` so GPU-saved weights load on CPU.

    Returns:
        The model with its state dict loaded, switched to eval mode.
    """
    net = Resnet50Flower102(dev)
    # NOTE(review): torch.load unpickles the file; fine for the checkpoint
    # shipped with the app, but never point this at user-supplied files.
    net.load_state_dict(torch.load("model.pth", map_location=torch.device(dev)))
    net.eval()
    return net


model = _load_model(device)
# Class-name lookup table; the "Name" column is read when showing a result.
flowers_data = pd.read_csv("flowerdata.csv")
uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])
# Evaluation-time preprocessing applied to each uploaded image:
#   1. resize to a fixed 224x224 input size,
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the (commonly used) ImageNet mean/std.
# NOTE(review): Normalize expects exactly 3 channels, so callers should hand
# in RGB images.
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform_val = transforms.Compose(_val_steps)
if uploaded_file is not None:
    # Force 3-channel RGB: PNG uploads can be RGBA or palette-mode and
    # grayscale images have a single channel, none of which match the
    # 3-element mean/std used by the Normalize step in transform_val.
    image = Image.open(uploaded_file).convert("RGB")

    # ToTensor already yields float32; add a batch dimension,
    # (C, H, W) -> (1, C, H, W), because the model is called on a batch of one.
    img = transform_val(image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        logits = model(img)

    # Index of the highest-scoring class for the single image in the batch,
    # as a plain Python int so it can be used for positional lookup.
    class_idx = int(logits.argmax(dim=1)[0].item())
    # NOTE(review): assumes row i of flowerdata.csv names model class i —
    # confirm against how the CSV was produced.
    flower_name = flowers_data["Name"].iloc[class_idx]

    st.header("Input Image")
    st.image(image=image, use_column_width=True)
    st.write("##", flower_name)