"""Streamlit app that classifies an uploaded flower photo.

Loads ``Resnet50Flower102`` weights from ``model.pth``. It maps the predicted
class index to a display name through the ``Name`` column of
``flowerdata.csv``, then shows the uploaded image and the predicted name.
"""
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd

st.title("Flower Image Classification")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model and CSV below are rebuilt/reloaded each time.
# Consider st.cache_resource / st.cache_data if the installed Streamlit
# version provides them.
model = Resnet50Flower102(device)
flowers_data = pd.read_csv("flowerdata.csv")

uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# Inference only: switch layers such as dropout / batch-norm to eval behavior
# once, instead of on every prediction.
model.eval()

# Resize to 224x224, convert to a float tensor, then normalize each of the
# 3 channels with the commonly used ImageNet mean/std values.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

if uploaded_file is not None:
    # Force 3-channel RGB: PNG uploads may be RGBA, palette or grayscale,
    # which would not match the 3-channel Normalize above.
    image = Image.open(uploaded_file).convert("RGB")

    # ToTensor already yields float32; add a batch dimension -> (1, 3, 224, 224).
    img = transform_val(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img)

    # Index of the highest-scoring class for the single image in the batch,
    # as a plain Python int so it can be used for positional lookup.
    class_idx = int(logits.argmax(dim=1)[0].item())
    # Assumes row i of flowerdata.csv describes model class i — TODO confirm.
    flower_name = flowers_data["Name"].iloc[class_idx]

    st.header("Input Image")
    # NOTE(review): newer Streamlit releases deprecate use_column_width in
    # favor of use_container_width; kept for compatibility with the
    # version this app was written against.
    st.image(image=image, use_column_width=True)
    st.write("##", flower_name)