PebinAPJ commited on
Commit
07fb3b2
·
verified ·
1 Parent(s): f1f4bb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -18
app.py CHANGED
@@ -1,37 +1,30 @@
1
  # app.py
2
  import gradio as gr
3
  import pandas as pd
4
- from sklearn.ensemble import RandomForestClassifier
5
- from sklearn.model_selection import train_test_split
6
- from sklearn.metrics import accuracy_score
7
 
8
- # Load the dataset
9
  data = pd.read_csv('Iris.csv')
 
10
 
11
- # Prepare features and target
12
- X = data.iloc[:, 1:-1] # Assuming the first column is an ID and last is the target
13
- y = data.iloc[:, -1]
14
-
15
- # Train a model
16
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
17
- model = RandomForestClassifier(random_state=42)
18
- model.fit(X_train, y_train)
19
 
20
  def classify_iris(sepal_length, sepal_width, petal_length, petal_width):
21
- """Function to classify iris species based on input features."""
22
  input_features = [[sepal_length, sepal_width, petal_length, petal_width]]
23
  prediction = model.predict(input_features)[0]
24
  return prediction
25
 
26
  # Define the Gradio interface
27
  inputs = [
28
- gr.inputs.Number(label="Sepal Length (cm)"),
29
- gr.inputs.Number(label="Sepal Width (cm)"),
30
- gr.inputs.Number(label="Petal Length (cm)"),
31
- gr.inputs.Number(label="Petal Width (cm)")
32
  ]
33
 
34
- outputs = gr.outputs.Textbox(label="Predicted Iris Species")
35
 
36
  description = "This app classifies iris species (Setosa, Versicolor, Virginica) based on the given features."
37
 
 
1
  # app.py
2
  import gradio as gr
3
  import pandas as pd
4
+ import joblib
 
 
5
 
6
+ # Load pre-trained model and dataset
7
  data = pd.read_csv('Iris.csv')
8
+ data.drop(columns=['Id'], inplace=True)
9
 
10
+ # Load the saved model
11
+ model = joblib.load('best_random_forest_model.pkl')
 
 
 
 
 
 
12
 
13
  def classify_iris(sepal_length, sepal_width, petal_length, petal_width):
14
+ """Classify iris species based on input features."""
15
  input_features = [[sepal_length, sepal_width, petal_length, petal_width]]
16
  prediction = model.predict(input_features)[0]
17
  return prediction
18
 
19
  # Define the Gradio interface
20
  inputs = [
21
+ gr.Number(label="Sepal Length (cm)"),
22
+ gr.Number(label="Sepal Width (cm)"),
23
+ gr.Number(label="Petal Length (cm)"),
24
+ gr.Number(label="Petal Width (cm)")
25
  ]
26
 
27
+ outputs = gr.Textbox(label="Predicted Iris Species")
28
 
29
  description = "This app classifies iris species (Setosa, Versicolor, Virginica) based on the given features."
30