|
import gradio as gr |
|
from transformers import AutoModelForImageClassification, AutoConfig |
|
from torchvision import transforms |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
# Hugging Face Hub repository id of the image-classification model served here.
MODEL_NAME = "dwililiya/sugarcane-plant-diseases-classification"

# Both objects are loaded once, at import time, from the same repository.
# The configuration is kept as a module-level name and handed explicitly to
# the model constructor.
config = AutoConfig.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
    config=config,
)
|
|
|
|
|
# Preprocessing for each uploaded PIL image:
#   1. resize to exactly 256x256 pixels;
#   2. convert to a float tensor laid out channels-first.
# NOTE(review): no transforms.Normalize step is applied. Confirm this matches
# the preprocessing the model was trained with (e.g. its image processor
# config on the Hub).
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)
|
|
|
|
|
# Human-readable labels, looked up by the arg-max position of the model's
# logits: element i names the class whose logit is in column i.
# NOTE(review): this order is assumed to match the model's config.id2label.
# Verify it if the checkpoint changes.
class_names = [
    "Bacterial Blight",
    "Healthy",
    "Mosaic",
    "Red Rot",
    "Rust",
    "Yellow",
]
|
|
|
def predict(image):
    """Classify a sugarcane leaf image.

    Args:
        image: A PIL image, as delivered by the ``gr.Image(type="pil")``
            input, or ``None`` when the form is submitted without an upload.

    Returns:
        A ``(predicted_class, confidence)`` tuple.

        - ``predicted_class`` is the entry of ``class_names`` at the arg-max
          of the model's logits.
        - ``confidence`` is the softmax probability of that class, as a
          Python float.
        - When no image is given, returns ``(None, None)`` so both output
          components are left empty instead of the call crashing.
    """
    if image is None:
        # Without this guard, transform(None) raises inside torchvision.
        return None, None

    # Force 3-channel RGB. RGBA/palette/grayscale images would otherwise
    # produce a tensor with a different channel count than a colour model
    # expects. Depending on the Gradio version, gr.Image's image_mode may
    # already do this; converting here keeps predict safe for any caller.
    batch = transform(image.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch).logits

    # Take the arg-max on the raw logits (same choice as torch.max(logits, 1))
    # and report that class's softmax probability.
    index = int(logits.argmax(dim=1)[0])
    probabilities = torch.softmax(logits, dim=1)[0]

    # NOTE(review): raises IndexError if the model emits more classes than
    # class_names lists; keep the two in sync.
    return class_names[index], probabilities[index].item()
|
|
|
|
|
# Web UI: a single PIL image in; the predicted label and its confidence out.
# The two output components receive, in order, the two values that predict
# returns.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(
        type="pil",
        label="Upload Sugarcane Leaf Image",
    ),
    outputs=[
        gr.Label(num_top_classes=1, label="Predicted Class"),
        gr.Textbox(label="Confidence Score"),
    ],
    title="Sugarcane Plant Diseases Classification",
    description="Upload an image of a sugarcane leaf to classify its disease.",
)
|
|
|
if __name__ == "__main__":
    # Only start the Gradio server when this file is executed directly,
    # not when it is imported.
    iface.launch()
|
|