Spaces:
Runtime error
Runtime error
import pandas as pd | |
import pmdarima as pm | |
import matplotlib.pyplot as plt | |
import gradio as gr | |
import io | |
import warnings | |
warnings.simplefilter("ignore") | |
def predict_timeseries(data_file, *, n_periods=12):
    """Fit an auto-ARIMA model to an uploaded monthly series and forecast it.

    The first CSV column is used as the date index and the last column is
    the series that gets modelled, plotted and forecast.

    Args:
        data_file: The uploaded CSV, either an object exposing the file path
            as ``.name`` or a plain path string.
        n_periods: Number of future monthly periods to forecast.

    Returns:
        A 5-tuple ``(head, fig_actual, arima_order, predictions, fig)``:
        the first five rows of the input, a figure of the observed series,
        the fitted ARIMA order, a DataFrame of forecasts indexed by
        month-start dates, and a figure overlaying observed and forecast
        values.

    Raises:
        ValueError: If no file was provided or the CSV holds no data.
    """
    if data_file is None:
        raise ValueError("No CSV file was provided.")
    # Accept both an upload object carrying its path in `.name` and a bare path.
    csv_path = getattr(data_file, "name", data_file)

    data = pd.read_csv(csv_path, index_col=[0], parse_dates=True)
    if data.empty:
        raise ValueError("The CSV file contains no data to model.")
    # Make sure the index really is datetime-like before doing date arithmetic.
    data.index = pd.to_datetime(data.index)

    # The last column is the target; any other columns are ignored by the model.
    target = data.columns[-1]
    series = data.iloc[:, -1]

    # Fit on a plain 1-D array so predict() hands back positional values
    # rather than something carrying its own index.
    model = pm.auto_arima(series.to_numpy(), seasonal=True, m=12)
    arima_order = model.order

    # Figure 1: the observed series on its own.
    # No plt.show() here: the figures are returned to the caller for rendering.
    fig_actual, ax_actual = plt.subplots()
    ax_actual.plot(series.index, series.to_numpy(), label=target)
    ax_actual.set_xlabel(data.index.name)
    ax_actual.set_ylabel(target)
    ax_actual.set_title("Plot of Actual data for {}".format(target))
    ax_actual.legend()

    predicted_values = model.predict(n_periods=n_periods)

    # Forecast dates start at the first month-start strictly after the last
    # observation. Adding MonthBegin(1) handles both a last date that already
    # sits on a month start and one that falls mid-month (which previously
    # caused the following month to be skipped).
    last_date = data.index[-1]
    pred_index = pd.date_range(
        start=last_date + pd.offsets.MonthBegin(1),
        periods=len(predicted_values),
        freq="MS",
        name=data.index.name,
    )
    predictions = pd.DataFrame({target: predicted_values}, index=pred_index)

    # Figure 2: observed series followed by the forecast.
    fig, ax = plt.subplots()
    ax.plot(series.index, series.to_numpy(), label='Actual')
    ax.plot(predictions.index, predictions[target].to_numpy(), label='Predicted')
    ax.set_xlabel(data.index.name)
    ax.set_ylabel(target)
    ax.set_title('Plot of Actual and Predicted Values')
    ax.legend()

    return data.head(), fig_actual, arima_order, predictions, fig
# Gradio front end: a single CSV upload goes in, and five outputs come back,
# one per element of the tuple returned by predict_timeseries, in order.
# NOTE(review): gr.inputs / gr.outputs look like Gradio's legacy component
# namespaces — confirm they exist in the Gradio version this app runs on.
input_data = gr.inputs.File(label="Upload CSV file")

_head_table = gr.outputs.Dataframe(
    type="pandas",
    label="FIRST FIVE ROWS OF DATASET",
)
_order_box = gr.outputs.Textbox(label="ARIMA ORDER")
_forecast_table = gr.outputs.Dataframe(
    type="pandas",
    label="PREDICTIONS FOR NEXT 12 PERIODS",
)

# 'plot' is Gradio's string shortcut for a plot output component.
outputs = [_head_table, 'plot', _order_box, _forecast_table, 'plot']

interface = gr.Interface(
    fn=predict_timeseries,
    inputs=input_data,
    outputs=outputs,
    title="Time series Forecast using AUTO ARIMA",
    description="Upload a CSV file of monthly time series data to generate 12 period forecasts using ARIMA.",
    theme='darkhuggingface',
    examples=["Electric_Production.csv"],
    live=False,
    cache_examples=False,
)

# Start the web app as soon as this script is executed.
interface.launch()