# IrisFlower / app.py
# PebinAPJ's picture
# Update app.py
# 07fb3b2 verified
# raw
# history blame
# 1.03 kB
# app.py
import gradio as gr
import pandas as pd
import joblib
# Load the reference dataset and the pre-trained classifier once, at import time.
# The 'Id' column is a row identifier, so it is dropped from the frame.
# NOTE(review): `data` is not referenced anywhere else in this file; confirm
# whether it is still needed before removing the CSV dependency.
data = pd.read_csv('Iris.csv').drop(columns=['Id'])
# joblib.load unpickles the file, which can execute arbitrary code. Only ever
# load this trusted, repository-bundled artifact, never a user-supplied path.
model = joblib.load('best_random_forest_model.pkl')
def classify_iris(sepal_length, sepal_width, petal_length, petal_width):
    """Classify iris species based on input features.

    Args:
        sepal_length: Sepal length in cm.
        sepal_width: Sepal width in cm.
        petal_length: Petal length in cm.
        petal_width: Petal width in cm.

    Returns:
        The label predicted by the module-level ``model`` for this single
        sample, converted to ``str`` for display in the output Textbox.

    Raises:
        gr.Error: If any measurement is missing. Gradio passes ``None`` for an
            empty Number field, and without this check the model would receive
            incomplete input. The user now sees a readable message instead.
    """
    features = (sepal_length, sepal_width, petal_length, petal_width)
    if any(value is None for value in features):
        raise gr.Error("Please enter all four measurements.")
    # predict() takes a batch of samples, so wrap the single sample in an outer
    # list. NOTE(review): this column order is assumed to match the order the
    # model was trained on; if it was fitted on a DataFrame, scikit-learn may
    # also warn about missing feature names here. Confirm against training code.
    input_features = [[float(value) for value in features]]
    prediction = model.predict(input_features)[0]
    return str(prediction)
# Define the Gradio interface: one numeric input per measurement, listed in the
# same order as classify_iris's parameters, and a single text output.
_FEATURE_LABELS = (
    "Sepal Length (cm)",
    "Sepal Width (cm)",
    "Petal Length (cm)",
    "Petal Width (cm)",
)
inputs = [gr.Number(label=text) for text in _FEATURE_LABELS]
outputs = gr.Textbox(label="Predicted Iris Species")
description = (
    "This app classifies iris species (Setosa, Versicolor, Virginica) "
    "based on the given features."
)
demo = gr.Interface(
    fn=classify_iris,
    inputs=inputs,
    outputs=outputs,
    title="Iris Species Classifier",
    description=description,
)
# Start the Gradio web server for this interface.
demo.launch()