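# Page-header residue captured with this file (not part of the program):
#   "dwililiya's picture" / commit message "Update app.py" / commit a9f69a2 (verified)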
import gradio as gr
from transformers import AutoModelForImageClassification, AutoConfig
from torchvision import transforms
from PIL import Image
import torch
# Labels shown to the user, indexed by the classifier's output position.
# NOTE(review): presumably this matches the label order used in training — confirm.
class_names = [
    'Bacterial Blight',
    'Healthy',
    'Mosaic',
    'Red Rot',
    'Rust',
    'Yellow',
]

# Preprocessing applied to every uploaded image: resize to 256x256, then
# convert the PIL image to a tensor.
transform = transforms.Compose(
    [transforms.Resize((256, 256)), transforms.ToTensor()]
)

# Fetch the configuration and the pretrained weights from the model hub.
MODEL_NAME = "dwililiya/sugarcane-plant-diseases-classification"
config = AutoConfig.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
)
def predict(image: "Image.Image") -> "tuple[str, float]":
    """Classify a sugarcane leaf image.

    Args:
        image: PIL image supplied by the Gradio image input. It is ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(predicted_class, confidence)`` pair: the entry of
        ``class_names`` with the highest logit, and the softmax probability
        of that class as a Python float.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque crash
        # inside the transform pipeline.
        raise gr.Error("Please upload an image before submitting.")

    # Uploads can be RGBA, palette or grayscale; ToTensor would then yield a
    # tensor with 4 or 1 channels. Normalise to 3-channel RGB first
    # (assumes the model was trained on RGB input — TODO confirm).
    batch = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    # Inference only: no gradients are needed for logits or probabilities.
    with torch.no_grad():
        logits = model(batch).logits
        probabilities = torch.softmax(logits, dim=1)[0]

    predicted_index = int(logits.argmax(dim=1).item())
    predicted_class = class_names[predicted_index]
    confidence = probabilities[predicted_index].item()
    return predicted_class, confidence
# Gradio UI: one image input (delivered to predict() as a PIL image) and two
# outputs that receive predict()'s (class, confidence) pair in order.
leaf_image_input = gr.Image(type="pil", label="Upload Sugarcane Leaf Image")
predicted_class_output = gr.Label(num_top_classes=1, label="Predicted Class")
confidence_output = gr.Textbox(label="Confidence Score")

iface = gr.Interface(
    fn=predict,
    inputs=leaf_image_input,
    outputs=[predicted_class_output, confidence_output],
    title="Sugarcane Plant Diseases Classification",
    description="Upload an image of a sugarcane leaf to classify its disease.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()