Spaces:
Sleeping
Sleeping
File size: 2,615 Bytes
68c77d8 e7f56c5 68c77d8 a6f03cc 68c77d8 c8db42d 68c77d8 0cddc48 68c77d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import requests
from torchvision.models import vgg19
import gradio as gr
# Preprocessing pipeline applied to every input image before inference:
#   1. resize to a fixed 224x224 resolution,
#   2. convert the PIL image to a float tensor (ToTensor scales pixels to [0, 1]),
#   3. normalize each RGB channel with the ImageNet mean / std.
# NOTE(review): presumably this matches the preprocessing used when
# rice_plant_classification.pth was trained — confirm against training code.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Build the classifier: a VGG19 architecture whose last fully connected layer
# is replaced by a 2-way head for binary (healthy / unhealthy) classification.
#
# weights=None on purpose: every parameter is overwritten by load_state_dict
# below, so fetching the ImageNet checkpoint first (weights='DEFAULT') only
# added a large download at startup and made the app fail without network.
model = models.vgg19(weights=None)
num_ftrs = model.classifier[-1].in_features  # input width of the original final layer
model.classifier[-1] = nn.Linear(num_ftrs, 2)  # new binary classification head

# Load the fine-tuned weights from the working directory; map_location lets a
# checkpoint saved on GPU be restored on a CPU-only host.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
model.load_state_dict(
    torch.load('rice_plant_classification.pth', map_location=torch.device('cpu'))
)
model.eval()  # evaluation mode for inference

# Map output indices of the 2-way head to human-readable labels.
class_to_label = {0: 'Healthy', 1: 'Unhealthy'}
# Inference function
def predict(image):
    """Classify a rice-plant image as 'Healthy' or 'Unhealthy'.

    Args:
        image: The uploaded image. Gradio's ``"image"`` input delivers a NumPy
            array by default; an already-decoded ``PIL.Image.Image`` is also
            accepted.

    Returns:
        A ``(label, confidence)`` tuple of strings: the label looked up in
        ``class_to_label`` for the highest-probability class, and that class's
        softmax probability formatted as a percentage (e.g. ``"97.31%"``).

    Raises:
        gr.Error: If no image was provided (Submit pressed on an empty input),
            so the UI shows a readable message instead of a raw traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image of a rice plant.")

    # Accept both NumPy arrays (Gradio's default) and PIL images.
    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # The normalization step in `preprocess` has exactly three per-channel
    # mean/std values, so force RGB to guard against grayscale or RGBA input.
    img = img.convert("RGB")
    batch = preprocess(img).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = model(batch)
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_class].item()

    return class_to_label[predicted_class], f'{confidence * 100:.2f}%'
# One-click example images shown under the input; these files must be present
# alongside this script.
example_images = ["healthy.jpg", "unhealthy.jpg"]

# Gradio UI: a single image input and two text outputs (label, confidence),
# matching the tuple returned by `predict`.
interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Textbox(label="Confidence"),
    ],
    title="Sheath Rot Disease Detection in Rice",
    description=(
        "This AI-powered interface utilizes a Convolutional Neural Network (CNN) "
        "model to detect sheath rot disease in rice plants. "
        "By analyzing uploaded images of rice crops, the model classifies them as "
        "either healthy or infected with sheath rot disease. "
        "This tool aims to support farmers and agronomists in early disease "
        "detection, allowing for timely intervention and improved crop management. "
        "Simply upload a rice plant image to get an instant diagnosis and help "
        "safeguard your yield from disease-related losses."
    ),
    examples=example_images,
)

# Start the web server only when executed as a script, not on import.
if __name__ == "__main__":
    interface.launch()
|