"""Gradio demo that classifies rice-plant images as Healthy or Unhealthy.

A VGG19 network whose last classifier layer is replaced by a 2-way linear
layer is loaded from ``rice_plant_classification.pth`` and served through a
Gradio interface.
"""

import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import requests
from torchvision.models import vgg19
import gradio as gr

# Fine-tuned checkpoint, resolved relative to the current working directory.
MODEL_PATH = 'rice_plant_classification.pth'

# Preprocessing applied to every input image before inference.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),  # Convert PIL image to a float tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

# Build the architecture WITHOUT pretrained ImageNet weights: every parameter
# is overwritten by the strict load_state_dict call below, so downloading the
# ImageNet checkpoint (weights='DEFAULT') was a large, pointless fetch.
model = models.vgg19(weights=None)
# Replace the final fully connected layer with a 2-class head so the layer
# shapes match the saved checkpoint.
num_ftrs = model.classifier[-1].in_features
model.classifier[-1] = nn.Linear(num_ftrs, 2)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load(MODEL_PATH, map_location=torch.device('cpu'),
                        weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Map from the model's output index to a human-readable label.
class_to_label = {0: 'Healthy', 1: 'Unhealthy'}


def predict(image):
    """Classify a single rice-plant image.

    Args:
        image: The image to classify, either as a NumPy array (what the
            Gradio ``"image"`` input passes) or as a ``PIL.Image.Image``.

    Returns:
        A ``(label, confidence)`` tuple, where ``label`` is a value of
        ``class_to_label`` and ``confidence`` is the softmax probability of
        that label formatted as a percentage string, e.g. ``'97.31%'``.

    Raises:
        gr.Error: If no image was supplied (``image`` is ``None``).
    """
    if image is None:
        # e.g. the user pressed Submit without uploading anything; surface a
        # readable message instead of a PIL traceback.
        raise gr.Error("Please upload an image.")

    img = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Force three channels: a grayscale or RGBA image would otherwise break
    # the 3-channel Normalize step.
    img = img.convert('RGB')
    batch = preprocess(img).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        logits = model(batch)
    probabilities = torch.softmax(logits, dim=1)
    confidence, predicted = probabilities.max(dim=1)
    predicted_class = predicted.item()

    return class_to_label[predicted_class], f'{confidence.item() * 100:.2f}%'


example_images = ["healthy.jpg", "unhealthy.jpg"]

# Implicit concatenation keeps the long description on several source lines
# without embedding a newline inside a single-quoted string literal.
DESCRIPTION = (
    "This AI-powered interface utilizes a Convolutional Neural Network (CNN) "
    "model to detect sheath rot disease in rice plants. "
    "By analyzing uploaded images of rice crops, the model classifies them as "
    "either healthy or infected with sheath rot disease. This tool aims to "
    "support farmers and agronomists in early disease detection, allowing for "
    "timely intervention and improved crop management. Simply upload a rice "
    "plant image to get an instant diagnosis and help safeguard your yield "
    "from disease-related losses."
)

# Create Gradio interface
interface = gr.Interface(
    fn=predict,
    inputs="image",
    outputs=[gr.Textbox(label="Prediction"), gr.Textbox(label="Confidence")],
    title="Sheath Rot Disease Detection in Rice",
    description=DESCRIPTION,
    examples=example_images,
)

# Launch the app
if __name__ == "__main__":
    interface.launch()