Stockholm_Test / app.py
SevenhuijsenM
Added gui
3829ec7
raw
history blame
2.21 kB
import gradio as gr
import plotly.graph_objects as go
import json
import requests
import os
from PIL import Image
import hopsworks
import joblib
import pandas as pd
import numpy as np
# Key read from the environment; os.getenv returns None when the variable is unset.
# NOTE(review): presumably the TomTom API key (going by the variable name) and
# not used by any code visible in this file - confirm whether it is still needed.
API_KEY = os.getenv("API-KEY-TOMTOM")

# Log into Hopsworks and obtain the feature-store and model-registry handles.
project = hopsworks.login()
fs = project.get_feature_store()
mr = project.get_model_registry()

# Look up version 1 of the registered model; download() returns the local
# directory that holds its files.
registry_entry = mr.get_model("stockholm_incidents_model", version=1)
model_dir = registry_entry.download()

# Load the pickled estimator; `model` is the name predict() uses.
model_path = model_dir + "/stckhlm_inc_model.pkl"
model = joblib.load(model_path)
print("Model downloaded")
def predict(magnitudeOfDelay, hour, iconCategory, latitude, longitude, month):
    """Predict the duration of a traffic incident from the form inputs.

    Builds a one-row DataFrame from the arguments, normalises the column
    names (lower case, spaces replaced by underscores) and feeds the row to
    the module-level ``model``.

    Args:
        magnitudeOfDelay: Value of the "Magnitude of Delay" input.
        hour: Value of the "Hour" input.
        iconCategory: Value of the "Icon Category" input.
        latitude: Value of the "Latitude" input.
        longitude: Value of the "Longitude" input.
        month: Value of the "Month" input.

    Returns:
        The first scalar of the model's prediction for the single input row.
    """
    # Column order handed to the model.
    feature_order = [
        'magnitudeOfDelay',
        'hour',
        'iconCategory',
        'latitude',
        'longitude',
        'month',
    ]
    row = {
        'magnitudeOfDelay': magnitudeOfDelay,
        'hour': hour,
        'iconCategory': iconCategory,
        'latitude': latitude,
        'longitude': longitude,
        'month': month,
    }
    # NOTE(review): iconCategory is passed through exactly as received from
    # the UI; confirm the model was trained on the same encoding.

    # Single-row frame with the columns in the explicit order above.
    df_row = pd.DataFrame(row, index=[0])[feature_order]

    # Feature names are normalised to lower case with underscores for spaces.
    df_row.columns = df_row.columns.str.lower().str.replace(' ', '_')

    raw_prediction = model.predict(df_row)

    # Flatten before indexing: a 1-D output (shape (1,)) and a 2-D output
    # (shape (1, 1)) both yield their only value. The previous double index
    # ``[0][0]`` raised IndexError on 1-D output.
    # NOTE(review): an earlier comment here said "Remove the log
    # transformation", but no inverse transform was ever applied. If the
    # model was trained on log-duration, the inverse belongs here - confirm
    # against the training pipeline.
    return np.ravel(raw_prediction)[0]
# Gradio form that collects the six model features and shows the predicted
# duration returned by predict().
demo = gr.Interface(
    fn=predict,
    title="Stockholm Incident Prediction",
    description="Predicts the duration of a traffic incident in Stockholm",
    allow_flagging="never",
    inputs=[
        # NOTE(review): the 0-60 range is kept as-is; confirm it matches the
        # scale of the magnitudeOfDelay feature used in training.
        gr.inputs.Slider(0, 60, label="Magnitude of Delay"),
        # Hour and month are whole numbers. Without an explicit step the
        # slider derives a fractional step from the range, which lets users
        # submit values such as hour 5.3.
        gr.inputs.Slider(0, 23, step=1, label="Hour"),
        gr.inputs.Radio(["Accident", "Construction", "Congestion", "Disabled Vehicle", "Mass Transit", "Miscellaneous", "Other News", "Planned Event", "Road Hazard", "Roadwork", "Traffic Flow", "Weather"], label="Icon Category"),
        gr.inputs.Slider(59.25, 59.40, label="Latitude"),
        gr.inputs.Slider(18.00, 18.16, label="Longitude"),
        gr.inputs.Slider(1, 12, step=1, label="Month"),
    ],
    outputs=[
        gr.outputs.Textbox(label="Duration"),
    ],
)
demo.launch()