smishr-18 commited on
Commit
f5b6d4a
·
verified ·
1 Parent(s): 2485e25

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py CHANGED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn as nn
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
+ from resnet import Resnet50Flower102
7
+ import pandas as pd
8
+ from dataloader import transform
9
+ st.title("Flower Image Classification")
10
+
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ model = Resnet50Flower102(device)
13
+ flowers_data = pd.read_csv("flowerdata.csv")
14
+ uploaded_file=st.file_uploader("Choose your file", type=["jpg", "png", "jpeg"])
15
+
16
+ model.load_state_dict(torch.load("model.pth", map_location=torch.device(device)))
17
+
18
+ transform_val = transforms.Compose([
19
+ transforms.Resize((224, 224)),
20
+ transforms.ToTensor(),
21
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
22
+ if uploaded_file is not None:
23
+ image = Image.open(uploaded_file)
24
+ img = transform_val(image)
25
+ img = img.type(torch.FloatTensor).to(device)
26
+ print(img.shape)
27
+ img = img.unsqueeze(0)
28
+ print(img.shape)
29
+
30
+ with torch.no_grad():
31
+ model.eval()
32
+ flower = model(img)
33
+ _, flower = flower.max(1)
34
+ flower = flower[0].detach().cpu().numpy()
35
+ flower_name = flowers_data["Name"][flower]
36
+ st.header("Input Image")
37
+ st.image(image=image, use_column_width=True)
38
+ st.write("##", flower_name)