# Hugging Face Space status when this file was captured: Runtime error
import gradio as gr | |
import plotly.graph_objects as go | |
import json | |
import requests | |
import os | |
from PIL import Image | |
import hopsworks | |
import joblib | |
import pandas as pd | |
import numpy as np | |
# TomTom API key taken from the environment (None when the variable is unset).
API_KEY = os.getenv("API-KEY-TOMTOM")

# Open a Hopsworks session and keep handles to the feature store and the
# model registry.
project = hopsworks.login()
fs = project.get_feature_store()
mr = project.get_model_registry()

# Download version 1 of the registered model, then load the pickled model
# from the downloaded directory into the module-level `model`.
model_dir = mr.get_model("stockholm_incidents_model", version=1).download()
model = joblib.load(model_dir + "/stckhlm_inc_model.pkl")
print("Model downloaded")
def predict(magnitudeOfDelay, hour, iconCategory, latitude, longitude, month):
    """Run the loaded model on a single set of form inputs.

    The inputs are packed into a one-row DataFrame whose columns are the
    lower-cased parameter names, in the order: magnitudeofdelay, hour,
    iconcategory, latitude, longitude, month. That frame is passed to the
    module-level ``model``.

    Args:
        magnitudeOfDelay: Value of the "Magnitude of Delay" input.
        hour: Value of the "Hour" input.
        iconCategory: Value of the "Icon Category" input.
        latitude: Value of the "Latitude" input.
        longitude: Value of the "Longitude" input.
        month: Value of the "Month" input.

    Returns:
        The first element of the model's output for the single input row.
    """
    # Dict insertion order fixes the column order of the resulting frame.
    features = {
        'magnitudeOfDelay': magnitudeOfDelay,
        'hour': hour,
        'iconCategory': iconCategory,
        'latitude': latitude,
        'longitude': longitude,
        'month': month,
    }
    df_row = pd.DataFrame(features, index=[0])

    # The frame is handed to the model with lower-case column names.
    df_row.columns = df_row.columns.str.lower()

    raw_output = model.predict(df_row)

    # Flatten before indexing so both a 1-D output of shape (1,) and a 2-D
    # output of shape (1, k) yield a scalar; indexing a 1-D result twice
    # would raise.
    # NOTE(review): an earlier comment here mentioned removing a log
    # transformation, but no inverse transform is applied — confirm whether
    # the model predicts a log-scaled duration.
    return np.ravel(raw_output)[0]
# Gradio UI wiring the six form inputs to `predict` and showing its result in
# a single text box.
#
# Components are taken from the top-level `gr` namespace: the legacy
# `gr.inputs.*` / `gr.outputs.*` namespaces are no longer available in
# current Gradio releases and fail at start-up.
#
# NOTE(review): no explicit `step` is set on the sliders, so "Hour" and
# "Month" may produce fractional values — confirm whether step=1 is wanted.
# NOTE(review): "Icon Category" is passed to `predict` as the selected label
# string — confirm this matches the encoding the model was trained on.
demo = gr.Interface(
    fn=predict,
    title="Stockholm Incident Prediction",
    description="Predicts the duration of a traffic incident in Stockholm",
    allow_flagging="never",
    inputs=[
        gr.Slider(0, 60, label="Magnitude of Delay"),
        gr.Slider(0, 23, label="Hour"),
        gr.Radio(
            [
                "Accident",
                "Construction",
                "Congestion",
                "Disabled Vehicle",
                "Mass Transit",
                "Miscellaneous",
                "Other News",
                "Planned Event",
                "Road Hazard",
                "Roadwork",
                "Traffic Flow",
                "Weather",
            ],
            label="Icon Category",
        ),
        gr.Slider(59.25, 59.40, label="Latitude"),
        gr.Slider(18.00, 18.16, label="Longitude"),
        gr.Slider(1, 12, label="Month"),
    ],
    outputs=[
        gr.Textbox(label="Duration"),
    ],
)

# Start the web server for the interface.
demo.launch()