Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -55,24 +55,36 @@ def train_arima(series, order=(5,1,0)):
|
|
55 |
return model_fit, forecast
|
56 |
|
57 |
# Create plot
|
58 |
-
|
59 |
-
|
60 |
-
plt.
|
61 |
-
plt.plot(
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
|
68 |
# Convert plot to base64 for Gradio
|
69 |
buf = io.BytesIO()
|
70 |
-
plt.savefig(buf, format='png')
|
71 |
buf.seek(0)
|
72 |
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
73 |
buf.close()
|
|
|
74 |
|
75 |
-
|
|
|
|
|
76 |
|
77 |
# Main prediction function
|
78 |
def predict(part_number, model_name):
|
@@ -84,11 +96,16 @@ def predict(part_number, model_name):
|
|
84 |
date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
|
85 |
df_part['Date'] = date_range
|
86 |
df_part.set_index('Date', inplace=True)
|
87 |
-
|
88 |
series = df_part['y'].astype(float)
|
89 |
-
|
|
|
|
|
|
|
90 |
if model_name == 'ARIMA':
|
91 |
model, forecast = train_arima(series)
|
|
|
|
|
|
|
92 |
|
93 |
# Calculate metrics
|
94 |
train_size = int(len(series) * 0.8)
|
|
|
55 |
return model_fit, forecast
|
56 |
|
57 |
# Create plot
|
58 |
+
# Create plot
|
59 |
+
def create_plot(historical, forecast, freq='M'):
|
60 |
+
plt.figure(figsize=(14, 7)) # bigger figure
|
61 |
+
plt.plot(historical.index, historical, label='Historical', linewidth=2)
|
62 |
+
|
63 |
+
# Generate forecast index
|
64 |
+
forecast_index = pd.date_range(
|
65 |
+
start=historical.index[-1] + pd.tseries.frequencies.to_offset(freq),
|
66 |
+
periods=len(forecast),
|
67 |
+
freq=freq
|
68 |
+
)
|
69 |
+
|
70 |
+
plt.plot(forecast_index, forecast, label='Forecast', color='orange', linewidth=2)
|
71 |
+
plt.legend(fontsize=12)
|
72 |
+
plt.title('Time Series Forecast', fontsize=16)
|
73 |
+
plt.xlabel('Time Period', fontsize=14)
|
74 |
+
plt.ylabel('Value', fontsize=14)
|
75 |
+
plt.grid(True, alpha=0.3)
|
76 |
|
77 |
# Convert plot to base64 for Gradio
|
78 |
buf = io.BytesIO()
|
79 |
+
plt.savefig(buf, format='png', bbox_inches="tight")
|
80 |
buf.seek(0)
|
81 |
img_str = base64.b64encode(buf.read()).decode('utf-8')
|
82 |
buf.close()
|
83 |
+
plt.close()
|
84 |
|
85 |
+
# Force full width in Gradio
|
86 |
+
return f'<img src="data:image/png;base64,{img_str}" style="width:100%; height:auto;" />'
|
87 |
+
|
88 |
|
89 |
# Main prediction function
|
90 |
def predict(part_number, model_name):
|
|
|
96 |
date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
|
97 |
df_part['Date'] = date_range
|
98 |
df_part.set_index('Date', inplace=True)
|
|
|
99 |
series = df_part['y'].astype(float)
|
100 |
+
|
101 |
+
# Detect frequency automatically
|
102 |
+
freq = pd.infer_freq(series.index) or 'M'
|
103 |
+
|
104 |
if model_name == 'ARIMA':
|
105 |
model, forecast = train_arima(series)
|
106 |
+
|
107 |
+
plot_html = create_plot(series, forecast, freq=freq)
|
108 |
+
|
109 |
|
110 |
# Calculate metrics
|
111 |
train_size = int(len(series) * 0.8)
|