File size: 1,220 Bytes
f5b6d4a
 
 
 
 
 
 
00b9bf6
f5b6d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd

st.title("Flower Image Classification")

# Render the uploader right after the title so it is visible while the
# (potentially slow) model loading below is still running.
uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model and CSV are reloaded each time. Consider wrapping
# the loading in a Streamlit caching decorator if the installed version has one.
model = Resnet50Flower102(device)
# torch.load unpickles the checkpoint; only point it at trusted files.
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# Inference only: put dropout / batch-norm layers into evaluation mode once.
model.eval()

# Lookup table for predicted class indices. Assumes row i of the "Name" column
# corresponds to class index i of the model output (TODO confirm against the
# label order used in training).
flowers_data = pd.read_csv("flowerdata.csv")

# Validation preprocessing: fixed 224x224 resize (aspect ratio not preserved),
# conversion to a float tensor, then per-channel normalization with the
# 3-channel mean/std below.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Normalize above expects exactly three channels; PNGs are often RGBA
        # (4 channels) and some images are grayscale (1 channel), so force RGB.
        # convert() also forces PIL to fully decode the file inside this try.
        rgb_image = image.convert("RGB")
    except OSError:
        # PIL signals undecodable / truncated files with OSError subclasses
        # (e.g. UnidentifiedImageError); report instead of crashing the app.
        st.error("Could not read the uploaded file as an image.")
    else:
        # Add a batch dimension of 1 and move the tensor next to the model.
        img = transform_val(rgb_image).unsqueeze(0).to(device)

        with torch.no_grad():
            # Presumably (batch, num_classes) scores — the class is taken as
            # the argmax over dim 1 of the single batch entry.
            scores = model(img)
        class_idx = int(scores.max(1)[1][0].item())

        # Positional lookup; equals the label lookup for read_csv's default
        # RangeIndex, and avoids indexing a Series with a 0-d numpy array.
        flower_name = flowers_data["Name"].iloc[class_idx]

        st.header("Input Image")
        st.image(image=image, use_column_width=True)
        st.write("##", flower_name)