File size: 1,252 Bytes
f5b6d4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import streamlit as st
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from resnet import Resnet50Flower102
import pandas as pd
from dataloader import transform
st.title("Flower Image Classification")

# Run inference on the GPU when torch can see one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): presumably Resnet50Flower102 places itself on `device`; confirm in resnet.py.
model = Resnet50Flower102(device)
# Lookup table for class names. Assumes row i (by position) is the name of model class i,
# in the same order used during training -- TODO confirm against the training labels.
flowers_data = pd.read_csv("flowerdata.csv")
uploaded_file = st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
# NOTE(review): torch.load unpickles the file -- only ever point this at a trusted checkpoint.
# NOTE(review): Streamlit re-executes this script on every interaction, so the weights and CSV
# are reloaded each time; consider st.cache_resource if the installed Streamlit supports it.
model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
# This app only predicts, so put the model in inference mode once, right after loading.
model.eval()

# Evaluation-time preprocessing: fixed 224x224 resize, conversion to a float tensor in [0, 1],
# then per-channel normalisation with the commonly used ImageNet mean/std. The 3-element
# mean/std means the input must have exactly three (RGB) channels.
transform_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        # Uploads may be RGBA/palette/grayscale (e.g. PNGs); the model path needs RGB.
        # The original `image` is kept untouched for display below.
        rgb_image = image.convert("RGB")
    except OSError:
        # PIL signals unreadable/corrupt image data with OSError (incl. UnidentifiedImageError).
        st.error("Could not read the uploaded file as an image.")
    else:
        # Preprocess and add a leading batch dimension: the model sees a batch of one.
        img = transform_val(rgb_image).unsqueeze(0).to(device)

        with torch.no_grad():
            scores = model(img)
        # Highest-scoring class along dim 1 (assumed to be the class axis of a
        # (batch, num_classes) output -- TODO confirm), for the single image in the batch.
        _, predicted = scores.max(1)
        flower = int(predicted[0].item())
        # Positional lookup so the class index maps to the CSV row regardless of its index labels.
        flower_name = flowers_data["Name"].iloc[flower]

        st.header("Input Image")
        st.image(image=image, use_column_width=True)
        st.write("##", flower_name)