# Hugging Face Space app (Space status when this file was captured: "Sleeping").
import torch | |
import torch.nn as nn | |
from torchvision import transforms, models | |
from PIL import Image | |
import requests | |
from torchvision.models import vgg19 | |
import gradio as gr | |
# Image preprocessing applied to every input before inference:
# fixed-size resize, conversion to a float tensor, then per-channel
# normalization with the ImageNet mean/std statistics.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

_preprocess_steps = [
    transforms.Resize((224, 224)),  # network input size: 224x224
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W)
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]
preprocess = transforms.Compose(_preprocess_steps)
# Load trained model.
# The VGG19 architecture is built WITHOUT pretrained weights: every parameter
# is overwritten by the strict load_state_dict call below, so fetching the
# ImageNet weights ('DEFAULT') would only add a network download and startup
# latency (and fail on machines without internet access).
model = models.vgg19(weights=None)
# Adjust the final fully connected layer for binary classification
num_ftrs = model.classifier[-1].in_features  # input features of the last layer
model.classifier[-1] = nn.Linear(num_ftrs, 2)  # new head with 2 outputs
# Load the fine-tuned weights onto the CPU (the file must exist in the working
# directory).
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source; on torch >= 1.13, passing weights_only=True restricts loading
# to tensors -- consider it if the deployed torch version allows.
state_dict = torch.load('rice_plant_classification.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval()  # inference mode: layers such as dropout switch to eval behavior
# Map each model output index to its human-readable class name
# (index 0 -> 'Healthy', index 1 -> 'Unhealthy').
class_to_label = dict(enumerate(('Healthy', 'Unhealthy')))
# Inference function
def predict(image):
    """Classify a rice-plant image as healthy or unhealthy.

    Args:
        image: The uploaded image from the Gradio "image" input (a NumPy
            array by default). A ``PIL.Image.Image`` is also accepted.
            ``None`` (Submit pressed without an image) is handled gracefully.

    Returns:
        A ``(label, confidence)`` tuple of strings, e.g. ``('Healthy', '97.31%')``,
        one value for each of the two output text boxes. When no image is
        given, returns ``('No image provided', '')``.
    """
    if image is None:
        # Gradio passes None when nothing was uploaded; Image.fromarray(None)
        # would otherwise raise an opaque AttributeError.
        return 'No image provided', ''

    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Normalize uses 3-channel mean/std; RGBA, grayscale or palette images
    # would otherwise fail there, so force RGB first.
    img = img.convert('RGB')
    batch = preprocess(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)

    # Perform inference without tracking gradients.
    with torch.no_grad():
        output = model(batch)
        probabilities = torch.softmax(output, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0, predicted_class].item()

    # Return the class label and confidence as a percentage string.
    return class_to_label[predicted_class], f'{confidence * 100:.2f}%'
# Example image paths offered in the UI (presumably relative to the working
# directory -- confirm the files ship with the app).
example_images = ["healthy.jpg", "unhealthy.jpg"]

_TITLE = "Sheath Rot Disease Detection in Rice"
_DESCRIPTION = (
    "This AI-powered interface utilizes a Convolutional Neural Network (CNN) "
    "model to detect sheath rot disease in rice plants. "
    "By analyzing uploaded images of rice crops, the model classifies them as "
    "either healthy or infected with sheath rot disease. "
    "This tool aims to support farmers and agronomists in early disease "
    "detection, allowing for timely intervention and improved crop management. "
    "Simply upload a rice plant image to get an instant diagnosis and help "
    "safeguard your yield from disease-related losses."
)

# Gradio UI: one image input, two text outputs (predicted label, confidence),
# matching the (label, confidence) tuple returned by predict.
interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Confidence"),
    ],
    title=_TITLE,
    description=_DESCRIPTION,
    examples=example_images,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()