izeeek commited on
Commit
68c77d8
·
verified ·
1 Parent(s): 370f486

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from PIL import Image
5
+ import requests
6
+ from torchvision.models import vgg19
7
+ import gradio as gr
8
+
9
+ # Define preprocessing
10
+ preprocess = transforms.Compose([
11
+ transforms.Resize((224, 224)), # Resize images to 224x224
12
+ transforms.ToTensor(), # Convert images to tensor
13
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalize using ImageNet stats
14
+ ])
15
+
16
+ # Load trained model
17
+ model = models.vgg19(weights='DEFAULT')
18
+
19
+ # Adjust the final fully connected layer for binary classification
20
+ num_ftrs = model.classifier[-1].in_features # Get the number of input features from the last layer
21
+ model.classifier[-1] = nn.Linear(num_ftrs, 2) # Replace with a new linear layer for binary classification
22
+
23
+ # Load the saved weights into the model
24
+ model.load_state_dict(torch.load('rice_plant_classification.pth', weights_only=True)) # Ensure this file exists
25
+ model.eval()
26
+
27
+ # Define class labels
28
+ class_to_label = {0: 'Healthy', 1: 'Unhealthy'}
29
+
30
+ # Inference function
31
+ def predict(image):
32
+ # Preprocess the image
33
+ img = Image.fromarray(image)
34
+ img = preprocess(img).unsqueeze(0) # Add batch dimension
35
+
36
+ # Perform inference
37
+ with torch.no_grad():
38
+ output = model(img)
39
+ probabilities = torch.softmax(output, dim=1)
40
+ predicted_class = torch.argmax(probabilities, 1).item()
41
+ confidence = probabilities[0][predicted_class].item()
42
+
43
+ # Return the class label and confidence
44
+ return class_to_label[predicted_class], f'{confidence * 100:.2f}%'
45
+
46
+ example_images = ["healthy.jpeg", "unhealthy.jpeg"]
47
+
48
+ # Create Gradio interface
49
+ interface = gr.Interface(fn=predict,
50
+ inputs="image",
51
+ outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Confidence")],
52
+ title="Healthy vs Unhealthy Rice Plant Classifier",
53
+ description="Upload a rice plant image to classify either it is healthy or unhealthy.",
54
+ examples=example_images
55
+ )
56
+
57
+ # Launch the app
58
+ if __name__ == "__main__":
59
+ interface.launch()