Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -55,24 +55,36 @@ def train_arima(series, order=(5,1,0)):
|
|
| 55 |
return model_fit, forecast
|
| 56 |
|
| 57 |
# Create plot
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
plt.
|
| 61 |
-
plt.plot(
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
# Convert plot to base64 for Gradio
|
| 69 |
buf = io.BytesIO()
|
| 70 |
-
plt.savefig(buf, format='png')
|
| 71 |
buf.seek(0)
|
| 72 |
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
| 73 |
buf.close()
|
|
|
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Main prediction function
|
| 78 |
def predict(part_number, model_name):
|
|
@@ -84,11 +96,16 @@ def predict(part_number, model_name):
|
|
| 84 |
date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
|
| 85 |
df_part['Date'] = date_range
|
| 86 |
df_part.set_index('Date', inplace=True)
|
| 87 |
-
|
| 88 |
series = df_part['y'].astype(float)
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
| 90 |
if model_name == 'ARIMA':
|
| 91 |
model, forecast = train_arima(series)
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
# Calculate metrics
|
| 94 |
train_size = int(len(series) * 0.8)
|
|
|
|
| 55 |
return model_fit, forecast
|
| 56 |
|
| 57 |
# Create plot
|
| 58 |
+
# Create plot
|
| 59 |
+
def create_plot(historical, forecast, freq='M'):
|
| 60 |
+
plt.figure(figsize=(14, 7)) # bigger figure
|
| 61 |
+
plt.plot(historical.index, historical, label='Historical', linewidth=2)
|
| 62 |
+
|
| 63 |
+
# Generate forecast index
|
| 64 |
+
forecast_index = pd.date_range(
|
| 65 |
+
start=historical.index[-1] + pd.tseries.frequencies.to_offset(freq),
|
| 66 |
+
periods=len(forecast),
|
| 67 |
+
freq=freq
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
plt.plot(forecast_index, forecast, label='Forecast', color='orange', linewidth=2)
|
| 71 |
+
plt.legend(fontsize=12)
|
| 72 |
+
plt.title('Time Series Forecast', fontsize=16)
|
| 73 |
+
plt.xlabel('Time Period', fontsize=14)
|
| 74 |
+
plt.ylabel('Value', fontsize=14)
|
| 75 |
+
plt.grid(True, alpha=0.3)
|
| 76 |
|
| 77 |
# Convert plot to base64 for Gradio
|
| 78 |
buf = io.BytesIO()
|
| 79 |
+
plt.savefig(buf, format='png', bbox_inches="tight")
|
| 80 |
buf.seek(0)
|
| 81 |
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
| 82 |
buf.close()
|
| 83 |
+
plt.close()
|
| 84 |
|
| 85 |
+
# Force full width in Gradio
|
| 86 |
+
return f'<img src="data:image/png;base64,{img_str}" style="width:100%; height:auto;" />'
|
| 87 |
+
|
| 88 |
|
| 89 |
# Main prediction function
|
| 90 |
def predict(part_number, model_name):
|
|
|
|
| 96 |
date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
|
| 97 |
df_part['Date'] = date_range
|
| 98 |
df_part.set_index('Date', inplace=True)
|
|
|
|
| 99 |
series = df_part['y'].astype(float)
|
| 100 |
+
|
| 101 |
+
# Detect frequency automatically
|
| 102 |
+
freq = pd.infer_freq(series.index) or 'M'
|
| 103 |
+
|
| 104 |
if model_name == 'ARIMA':
|
| 105 |
model, forecast = train_arima(series)
|
| 106 |
+
|
| 107 |
+
plot_html = create_plot(series, forecast, freq=freq)
|
| 108 |
+
|
| 109 |
|
| 110 |
# Calculate metrics
|
| 111 |
train_size = int(len(series) * 0.8)
|