# ab_challenge / get_forecast.py
# stinoco's picture
# changed labels
# 111b9fa
import pandas as pd
from dateutil.relativedelta import relativedelta
import matplotlib.pyplot as plt
from prophet.serialize import model_from_json
# Model and history file names for each supported series.
_SERIES_FILES = {
    'SOM': ('som_model.json', 'som_data.csv'),
    'Volumen': ('volumen_model.json', 'volumen_data.csv'),
}


def get_forecast(serie: str, periods: int, percent_change: float) -> "tuple[plt.Figure, pd.DataFrame]":
    """Forecast a series under a price change and under unchanged prices.

    Loads the serialized Prophet model from ``models/`` and the matching
    history CSV from ``data/`` (both relative to the working directory),
    builds ``periods`` future month-start dates, and predicts twice:

    * with ``precio_hl`` set to the value observed one year before each
      future date, scaled by ``percent_change`` percent;
    * with ``precio_hl`` set to that year-earlier value unchanged.

    Args:
        serie: Series to forecast, either ``'SOM'`` or ``'Volumen'``.
        periods: Number of future months to predict (at least 1). Every
            future month needs a ``precio_hl`` observation exactly one year
            earlier in the history file.
        percent_change: Price change in percent (``10`` means +10 %).

    Returns:
        A ``(figure, table)`` tuple. ``figure`` plots the history and both
        forecasts. ``table`` has the columns ``Date``,
        ``'<serie> changing prices'``, ``'<serie> holding prices'`` and
        ``Diff`` (changing minus holding), with values rounded to 4 decimals.

    Raises:
        ValueError: If ``serie`` is not supported, ``periods`` is smaller
            than 1, or the history has no ``precio_hl`` value one year
            before one of the future dates.
    """
    if not isinstance(serie, str) or serie not in _SERIES_FILES:
        raise ValueError('Input not valid')
    if periods < 1:
        raise ValueError(f'periods must be at least 1, got {periods!r}')
    model_file, data_file = _SERIES_FILES[serie]

    # Load the fitted model and the data it was trained on.
    with open('models/' + model_file, 'r', encoding='utf-8') as fin:
        model = model_from_json(fin.read())
    history = pd.read_csv('data/' + data_file, encoding='utf-8-sig', sep=';')
    history['ds'] = pd.to_datetime(history['ds'])

    # Keep only the new dates. Copy so that adding the price column below
    # does not write into a slice of the full frame.
    future = model.make_future_dataframe(periods=periods, freq='MS')
    future = future.tail(periods).copy()

    past_prices = _prices_one_year_before(history, future['ds'])

    # Scenario 1: prices changed by the requested percentage.
    future['precio_hl'] = past_prices * (1 + percent_change / 100)
    future_values = model.predict(future)[['ds', 'yhat']]

    # Scenario 2: prices held at last year's level.
    future_0 = future.copy()
    future_0['precio_hl'] = past_prices
    values_0 = model.predict(future_0)[['ds', 'yhat']]

    # Result table; Diff is computed on unrounded values, then all rounded.
    changed_col = f'{serie} changing prices'
    held_col = f'{serie} holding prices'
    df_future = pd.DataFrame({
        'Date': future_values['ds'].dt.date,
        changed_col: future_values['yhat'],
        held_col: values_0['yhat'],
    })
    df_future['Diff'] = df_future[changed_col] - df_future[held_col]
    value_cols = [changed_col, held_col, 'Diff']
    df_future[value_cols] = df_future[value_cols].round(4)

    fig = _plot_forecast(history, future_values, values_0, serie)
    return fig, df_future


def _prices_one_year_before(history, future_dates):
    """Return the ``precio_hl`` observed one year before each future date.

    The result is aligned with ``future_dates`` (same length, same order).
    Raises ``ValueError`` when any of the year-earlier dates has no price in
    ``history`` instead of returning a shorter array.
    """
    lookup_dates = future_dates.apply(lambda d: d - relativedelta(years=1))
    price_by_date = history.set_index('ds')['precio_hl']
    prices = price_by_date.reindex(lookup_dates)

    missing = lookup_dates[prices.isna().to_numpy()]
    if not missing.empty:
        listed = ', '.join(d.strftime('%Y-%m-%d') for d in missing)
        raise ValueError(
            f"No historical 'precio_hl' available for: {listed}. "
            'Each forecast month needs a price observed one year earlier.'
        )
    return prices.to_numpy()


def _plot_forecast(history, changed, held, serie):
    """Plot the history and both forecast scenarios; return the figure."""
    # Prepend the last observation so the forecast lines connect to history.
    last_obs = history.iloc[-1:][['ds', 'y']].rename(columns={'y': 'yhat'})
    changed_line = pd.concat([last_obs, changed])
    held_line = pd.concat([last_obs, held])

    fig, ax = plt.subplots()
    ax.plot(changed_line['ds'], changed_line['yhat'],
            label='Price Policy Change Forecast', marker='.', color='C1')
    ax.plot(held_line['ds'], held_line['yhat'],
            label='No Policy Change Forecast', marker='.', color='C2')
    ax.plot(history['ds'], history['y'],
            label='Historic data', marker='.', color='C0')
    ax.tick_params(axis='x', rotation=45)
    ax.set_ylabel(serie)
    fig.tight_layout()
    ax.legend()
    return fig