Spaces:
Sleeping
Sleeping
File size: 1,220 Bytes
f5b6d4a 00b9bf6 f5b6d4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd
st.title("Flower Image Classification")

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the network and restore its trained weights before any inference.
# `map_location` lets a checkpoint saved on a GPU load on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
# NOTE(review): Streamlit re-executes this whole script on every interaction, so
# the model and CSV below are reloaded each time; consider st.cache_resource /
# st.cache_data if the installed Streamlit version provides them.
model = Resnet50Flower102(device)
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# Switch to evaluation behaviour (e.g. BatchNorm/Dropout) once, up front.
model.eval()

# Lookup table mapping the predicted class index to a display name.
# Presumably one row per class, ordered by the model's output index -- TODO confirm
# against the labels used in training.
flowers_data = pd.read_csv("flowerdata.csv")

# Preprocessing: fixed 224x224 resize, tensor conversion, then normalisation with
# the standard ImageNet channel mean/std. Presumably mirrors the validation
# transform used during training -- confirm.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

if uploaded_file is not None:
    # PNG uploads are frequently RGBA, palette ("P") or grayscale ("L"); the
    # 3-channel Normalize above would fail on those, so force RGB first.
    image = Image.open(uploaded_file).convert("RGB")

    # ToTensor already yields float32; add the batch dimension the model
    # expects: (C, H, W) -> (1, C, H, W).
    img = transform_val(image).unsqueeze(0).to(device)

    with torch.no_grad():
        logits = model(img)

    # Highest-scoring class for the single image in the batch, as a plain int.
    class_idx = int(logits.argmax(dim=1)[0])
    # Positional lookup; read_csv gives a RangeIndex, so this matches the row order.
    flower_name = flowers_data["Name"].iloc[class_idx]

    st.header("Input Image")
    # NOTE(review): newer Streamlit releases deprecate use_column_width in favour
    # of use_container_width -- switch once the pinned version supports it.
    st.image(image=image, use_column_width=True)
    st.write("##", flower_name)