# iris / app.py
# SevenhuijsenM
# Iris prediction
# bc31129
# raw
# history blame
# 1.16 kB
import gradio as gr
import hopsworks
import joblib
import pandas as pd
import numpy as np
# Log in to Hopsworks and obtain handles to the project's feature store
# and model registry.
project = hopsworks.login()
fs = project.get_feature_store()
mr = project.get_model_registry()

# Look up version 1 of "iris_model" in the registry, download its artifacts
# to a local directory, then load the serialized model that `greet` calls
# `predict` on.
model_dir = mr.get_model("iris_model", version=1).download()
model = joblib.load(model_dir + "/iris_model.pkl")

# Resolve version 1 of the "iris" feature view and fetch its batch data.
# NOTE(review): `batch_data` is not referenced anywhere in the visible code;
# confirm whether this startup fetch is still needed.
feature_view = fs.get_feature_view(name="iris", version=1)
batch_data = feature_view.get_batch_data()
def greet(sep_length, sep_width, pet_length, pet_width):
    """Predict the iris class for one flower and return its image path.

    The four measurements are converted to ``float`` before being handed to
    the module-level ``model``, because the UI delivers them as text.

    Args:
        sep_length: Sepal length, as a number or a numeric string.
        sep_width: Sepal width, as a number or a numeric string.
        pet_length: Petal length, as a number or a numeric string.
        pet_width: Petal width, as a number or a numeric string.

    Returns:
        A ``(prediction, image_path)`` tuple, where ``prediction`` is the
        first element of ``model.predict`` for the single-row frame and
        ``image_path`` is ``"Images/<prediction>.jpg"``.

    Raises:
        ValueError: If a measurement is a string that is not a valid number.
        TypeError: If a measurement is of a type ``float`` cannot convert
            (for example ``None``).
    """
    # Explicit conversion keeps the columns numeric instead of relying on
    # the model library to coerce string columns on its own.
    df = pd.DataFrame(
        {
            "sepal_length": float(sep_length),
            "sepal_width": float(sep_width),
            "petal_length": float(pet_length),
            "petal_width": float(pet_width),
        },
        index=[0],
    )
    prediction = model.predict(df)[0]
    # Use a forward slash: it is accepted on Windows and POSIX alike, whereas
    # a backslash is not a path separator on POSIX and "\{" is an invalid
    # escape sequence in a string literal.
    return prediction, f"Images/{prediction}.jpg"
# Build the UI: four free-text inputs (one per measurement, in the order
# `greet` expects its arguments) mapped to a text prediction and an image.
predict = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(placeholder=hint, label=caption)
        for hint, caption in (
            ("Sepal length here", "Sepal Length"),
            ("Sepal width here", "Sepal Width"),
            ("Petal length here", "Petal Length"),
            ("Petal width here", "Petal Width"),
        )
    ],
    outputs=[
        gr.Textbox(label="Prediction"),
        gr.Image(type="pil", label="Iris"),
    ],
)

# Start serving the interface.
predict.launch()