Prathamesh1420 commited on
Commit
385cdd3
·
verified ·
1 Parent(s): 76fdcc7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -13
app.py CHANGED
@@ -55,24 +55,36 @@ def train_arima(series, order=(5,1,0)):
55
  return model_fit, forecast
56
 
57
  # Create plot
58
- def create_plot(historical, forecast):
59
- plt.figure(figsize=(12, 6))
60
- plt.plot(historical.index, historical, label='Historical')
61
- plt.plot(range(len(historical), len(historical) + len(forecast)), forecast,
62
- label='Forecast', color='orange')
63
- plt.legend()
64
- plt.title('Time Series Forecast')
65
- plt.xlabel('Time Period')
66
- plt.ylabel('Value')
 
 
 
 
 
 
 
 
 
67
 
68
  # Convert plot to base64 for Gradio
69
  buf = io.BytesIO()
70
- plt.savefig(buf, format='png')
71
  buf.seek(0)
72
  img_str = base64.b64encode(buf.read()).decode('utf-8')
73
  buf.close()
 
74
 
75
- return f'<img src="data:image/png;base64,{img_str}" />'
 
 
76
 
77
  # Main prediction function
78
  def predict(part_number, model_name):
@@ -84,11 +96,16 @@ def predict(part_number, model_name):
84
  date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
85
  df_part['Date'] = date_range
86
  df_part.set_index('Date', inplace=True)
87
-
88
  series = df_part['y'].astype(float)
89
-
 
 
 
90
  if model_name == 'ARIMA':
91
  model, forecast = train_arima(series)
 
 
 
92
 
93
  # Calculate metrics
94
  train_size = int(len(series) * 0.8)
 
55
  return model_fit, forecast
56
 
57
  # Create plot
58
+ # Create plot
59
+ def create_plot(historical, forecast, freq='M'):
60
+ plt.figure(figsize=(14, 7)) # bigger figure
61
+ plt.plot(historical.index, historical, label='Historical', linewidth=2)
62
+
63
+ # Generate forecast index
64
+ forecast_index = pd.date_range(
65
+ start=historical.index[-1] + pd.tseries.frequencies.to_offset(freq),
66
+ periods=len(forecast),
67
+ freq=freq
68
+ )
69
+
70
+ plt.plot(forecast_index, forecast, label='Forecast', color='orange', linewidth=2)
71
+ plt.legend(fontsize=12)
72
+ plt.title('Time Series Forecast', fontsize=16)
73
+ plt.xlabel('Time Period', fontsize=14)
74
+ plt.ylabel('Value', fontsize=14)
75
+ plt.grid(True, alpha=0.3)
76
 
77
  # Convert plot to base64 for Gradio
78
  buf = io.BytesIO()
79
+ plt.savefig(buf, format='png', bbox_inches="tight")
80
  buf.seek(0)
81
  img_str = base64.b64encode(buf.read()).decode('utf-8')
82
  buf.close()
83
+ plt.close()
84
 
85
+ # Force full width in Gradio
86
+ return f'<img src="data:image/png;base64,{img_str}" style="width:100%; height:auto;" />'
87
+
88
 
89
  # Main prediction function
90
  def predict(part_number, model_name):
 
96
  date_range = pd.date_range(start=start_date, periods=len(df_part), freq='W')
97
  df_part['Date'] = date_range
98
  df_part.set_index('Date', inplace=True)
 
99
  series = df_part['y'].astype(float)
100
+
101
+ # Detect frequency automatically
102
+ freq = pd.infer_freq(series.index) or 'M'
103
+
104
  if model_name == 'ARIMA':
105
  model, forecast = train_arima(series)
106
+
107
+ plot_html = create_plot(series, forecast, freq=freq)
108
+
109
 
110
  # Calculate metrics
111
  train_size = int(len(series) * 0.8)