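# NOTE: the next three lines are page-header residue from the repository
# web view; they are kept as comments so the file parses as Python.
# izeeek's picture
# Update app.py
# 0cddc48 verified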
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import requests
from torchvision.models import vgg19
import gradio as gr
# Input pipeline applied to every image before it reaches the network.
_INPUT_SIZE = (224, 224)               # spatial size the images are resized to
_IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel ImageNet mean
_IMAGENET_STD = [0.229, 0.224, 0.225]   # per-channel ImageNet std

_preprocess_steps = [
    transforms.Resize(_INPUT_SIZE),  # resize to 224x224
    transforms.ToTensor(),           # convert the image to a tensor
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),  # normalize with ImageNet stats
]
preprocess = transforms.Compose(_preprocess_steps)
# Build the VGG19 architecture WITHOUT pretrained weights. The checkpoint
# loaded below replaces every parameter (load_state_dict is strict by
# default), so fetching the ImageNet weights first would only add a large,
# network-dependent download at startup.
model = models.vgg19(weights=None)

# Replace the last classifier layer with a two-output head for binary
# classification, keeping the same number of input features.
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, 2)

# Load the fine-tuned weights onto the CPU. The path is relative to the
# working directory, so the checkpoint file must exist there.
# NOTE(review): torch.load unpickles the file; only ship a trusted
# checkpoint, and consider weights_only=True if the installed torch
# version supports it.
state_dict = torch.load('rice_plant_classification.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # switch to evaluation mode for inference

# Mapping from the model's output index to a human-readable label.
class_to_label = {0: 'Healthy', 1: 'Unhealthy'}
def predict(image):
    """Classify a rice-plant image and report the model's confidence.

    Args:
        image: The image as an array accepted by ``Image.fromarray``, or
            ``None`` when no image was provided.

    Returns:
        A ``(label, confidence)`` tuple of strings. ``label`` comes from
        ``class_to_label`` and ``confidence`` is the softmax probability of
        the predicted class formatted as a percentage with two decimals.
        When ``image`` is ``None`` a placeholder message and an empty
        confidence string are returned instead of raising.
    """
    # Guard against an empty submission so the caller gets a readable
    # message rather than an opaque error from Image.fromarray.
    if image is None:
        return 'No image provided', ''

    # Force three channels: the normalization step carries three-channel
    # statistics, so grayscale or RGBA input would otherwise fail.
    img = Image.fromarray(image).convert('RGB')
    batch = preprocess(img).unsqueeze(0)  # add batch dimension

    # Inference only; no gradients are needed.
    with torch.no_grad():
        output = model(batch)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, 1).item()
        confidence = probabilities[0][predicted_class].item()

    return class_to_label[predicted_class], f'{confidence * 100:.2f}%'
# Sample images offered in the UI as one-click examples.
example_images = ["healthy.jpg", "unhealthy.jpg"]

# Long-form description shown in the UI, split by sentence for readability.
_DESCRIPTION = (
    "This AI-powered interface utilizes a Convolutional Neural Network (CNN) model to detect sheath rot disease in rice plants. "
    "By analyzing uploaded images of rice crops, the model classifies them as either healthy or infected with sheath rot disease. "
    "This tool aims to support farmers and agronomists in early disease detection, allowing for timely intervention and improved crop management. "
    "Simply upload a rice plant image to get an instant diagnosis and help safeguard your yield from disease-related losses."
)

# Gradio interface: one image input, two text outputs fed by predict().
interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Confidence"),
    ],
    title="Sheath Rot Disease Detection in Rice",
    description=_DESCRIPTION,
    examples=example_images,
)

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()