# classification / app.py
# Last updated by srinuksv ("Update app.py", commit ff8351d).
import logging
import os

import gradio as gr
import numpy as np
from keras.models import load_model
from keras.preprocessing import image
from keras.preprocessing.image import img_to_array
# Load the pre-trained model once at import time so every request reuses it.
# Resolve the file next to this script first so the app works no matter which
# directory it is launched from; fall back to the original CWD-relative path.
_MODEL_FILENAME = 'Mango.h5'
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _MODEL_FILENAME)
if not os.path.exists(model_path):
    model_path = _MODEL_FILENAME
# compile=False: the model is only used for inference (model.predict), so the
# training configuration (optimizer/loss) does not need to be restored.
model = load_model(model_path, compile=False)
def predict_disease(image_file, model, all_labels):
"""
Predict the disease from an image using the trained model.
Parameters:
- image_file: image, input image file
- model: Keras model, trained convolutional neural network
- all_labels: list, list of class labels
Returns:
- str, predicted class label
"""
try:
# Load and preprocess the image
img = image.load_img(image_file, target_size=(256, 256))
img_array = img_to_array(img)
img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
img_array = img_array / 255.0 # Normalize the image
# Predict the class
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions[0])
# Return the class label
return all_labels[predicted_class]
except Exception as e:
print(f"Error: {e}")
return None
# Class labels, in model output order: predict_disease returns
# all_labels[argmax(predictions)], so position i must name output unit i.
# Do not reorder these entries without retraining/re-exporting the model.
all_labels = [
    'Mango Anthracnose',
    'Mango Bacterial Canker',
    'Mango Cutting weevil',
    'Mango Die Back',
    'Mango Gall Midge',
    'Mango Healthy',
    'Mango powdery mildew',
    'Mango Sooty Mould',
]
# Prediction callback used by the Gradio interface.
def gradio_predict(image_file):
    """Classify one uploaded image using the module-level model and labels.

    Thin adapter: Gradio supplies only the image, and this binds the loaded
    ``model`` and ``all_labels`` before delegating to ``predict_disease``.
    """
    label = predict_disease(
        image_file=image_file,
        model=model,
        all_labels=all_labels,
    )
    return label
# Create a Gradio interface around the prediction callback.
gr_interface = gr.Interface(
    fn=gradio_predict,  # Function to call for predictions
    inputs=gr.Image(type="filepath"),  # Callback receives a path to the uploaded file
    outputs="text",  # Output will be the class label as text
    title="Plant Disease Predictor",
    description="Upload an image of a plant to predict the disease.",
)

# Launch the Gradio app only when run as a script (e.g. `python app.py`), so
# importing this module does not start a web server as a side effect.
if __name__ == "__main__":
    gr_interface.launch()