tbdavid2019 commited on
Commit
1acfd2b
·
1 Parent(s): ca21317

回到最初

Browse files
Files changed (1) hide show
  1. app.py +61 -70
app.py CHANGED
@@ -16,7 +16,6 @@ import os
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
19
- from prophet import Prophet
20
 
21
  # 設置日誌
22
  logging.basicConfig(level=logging.INFO,
@@ -27,9 +26,11 @@ def setup_font():
27
  try:
28
  url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
29
  response_font = requests.get(url_font)
 
30
  with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
31
  tmp_file.write(response_font.content)
32
  tmp_file_path = tmp_file.name
 
33
  fm.fontManager.addfont(tmp_file_path)
34
  mpl.rc('font', family='Taipei Sans TC Beta')
35
  except Exception as e:
@@ -51,25 +52,30 @@ def fetch_stock_categories():
51
  url = "https://tw.stock.yahoo.com/class/"
52
  response = requests.get(url, headers=headers, timeout=10)
53
  response.raise_for_status()
 
54
  soup = BeautifulSoup(response.text, 'html.parser')
55
  main_categories = soup.find_all('div', class_='C($c-link-text)')
 
56
  data = []
57
  for category in main_categories:
58
  main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
59
  if main_category_name:
60
  main_category_name = main_category_name.text.strip()
61
  sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
 
62
  for sub_category in sub_categories:
63
  data.append({
64
  '台股': main_category_name,
65
  '類股': sub_category.text.strip(),
66
  '網址': "https://tw.stock.yahoo.com" + sub_category['href']
67
  })
 
68
  category_dict = {}
69
  for item in data:
70
  if item['台股'] not in category_dict:
71
  category_dict[item['台股']] = []
72
  category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
 
73
  return category_dict
74
  except Exception as e:
75
  logging.error(f"獲取股票類別失敗: {str(e)}")
@@ -78,16 +84,17 @@ def fetch_stock_categories():
78
  # 股票預測模型類別
79
  class StockPredictor:
80
  def __init__(self):
81
- self.lstm_model = None
82
- self.prophet_model = None
83
  self.scaler = MinMaxScaler()
84
-
85
  def prepare_data(self, df, selected_features):
86
  scaled_data = self.scaler.fit_transform(df[selected_features])
 
87
  X, y = [], []
88
  for i in range(len(scaled_data) - 1):
89
  X.append(scaled_data[i])
90
  y.append(scaled_data[i+1])
 
91
  return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
92
 
93
  def build_model(self, input_shape):
@@ -103,8 +110,8 @@ class StockPredictor:
103
 
104
  def train(self, df, selected_features):
105
  X, y = self.prepare_data(df, selected_features)
106
- self.lstm_model = self.build_model((1, X.shape[2]))
107
- history = self.lstm_model.fit(
108
  X, y,
109
  epochs=50,
110
  batch_size=32,
@@ -116,18 +123,18 @@ class StockPredictor:
116
  def predict(self, last_data, n_days):
117
  predictions = []
118
  current_data = last_data.copy()
 
119
  for _ in range(n_days):
120
- next_day = self.lstm_model.predict(current_data.reshape(1, 1, -1), verbose=0)
121
  predictions.append(next_day[0])
 
122
  current_data = current_data.flatten()
123
  current_data[:len(next_day[0])] = next_day[0]
124
  current_data = current_data.reshape(1, -1)
 
125
  return np.array(predictions)
126
-
127
- def train_prophet(self, df_prophet):
128
- self.prophet_model = Prophet()
129
- self.prophet_model.fit(df_prophet)
130
 
 
131
  def update_stocks(category):
132
  if not category or category not in category_dict:
133
  return []
@@ -137,8 +144,10 @@ def get_stock_items(url):
137
  try:
138
  response = requests.get(url, headers=headers, timeout=10)
139
  response.raise_for_status()
 
140
  soup = BeautifulSoup(response.text, 'html.parser')
141
  stock_items = soup.find_all('li', class_='List(n)')
 
142
  stocks_dict = {}
143
  for item in stock_items:
144
  stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
@@ -148,6 +157,7 @@ def get_stock_items(url):
148
  display_code = full_code.split('.')[0]
149
  display_name = f"{stock_name.text.strip()}{display_code}"
150
  stocks_dict[display_name] = full_code
 
151
  return stocks_dict
152
  except Exception as e:
153
  logging.error(f"獲取股票項目失敗: {str(e)}")
@@ -169,8 +179,10 @@ def update_stock(category, stock):
169
  stock_plot: gr.update(value=None),
170
  status_output: gr.update(value="")
171
  }
 
172
  url = next((item['網址'] for item in category_dict.get(category, [])
173
  if item['類股'] == stock), None)
 
174
  if url:
175
  stock_items = get_stock_items(url)
176
  return {
@@ -184,84 +196,65 @@ def update_stock(category, stock):
184
  status_output: gr.update(value="")
185
  }
186
 
187
- def predict_stock(category, stock, stock_item, period, selected_features, model_type):
188
  if not all([category, stock, stock_item]):
189
  return gr.update(value=None), "請選擇產業類別、類股和股票"
 
190
  try:
191
  url = next((item['網址'] for item in category_dict.get(category, [])
192
- if item['類股'] == stock), None)
193
  if not url:
194
  return gr.update(value=None), "無法獲取類股網址"
 
195
  stock_items = get_stock_items(url)
196
  stock_code = stock_items.get(stock_item, "")
 
197
  if not stock_code:
198
  return gr.update(value=None), "無法獲取股票代碼"
199
 
200
- # 下載股票數據
201
  df = yf.download(stock_code, period=period)
202
  if df.empty:
203
  raise ValueError("無法獲取股票數據")
204
 
 
205
  predictor = StockPredictor()
 
 
 
 
 
 
 
 
 
 
 
206
 
207
- if model_type == "LSTM":
208
- predictor.train(df, selected_features)
209
- last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
210
- predictions = predictor.predict(last_data[0], 5)
211
- last_original = df[selected_features].iloc[-1].values
212
- predictions_original = predictor.scaler.inverse_transform(
213
- np.vstack([last_data, predictions])
214
- )
215
- all_predictions = np.vstack([last_original, predictions_original[1:]])
216
-
217
- elif model_type == "Prophet":
218
- target_feature = selected_features[0] # 使用第一個選擇的特徵
219
- df_prophet = df.reset_index()
220
- df_prophet = df_prophet[['Date', target_feature]].rename(
221
- columns={'Date': 'ds', target_feature: 'y'})
222
-
223
- predictor.train_prophet(df_prophet)
224
- future_dates = pd.date_range(
225
- start=df_prophet['ds'].iloc[-1] + pd.Timedelta(days=1),
226
- periods=5,
227
- freq='D'
228
- )
229
- future = pd.DataFrame({'ds': future_dates})
230
- forecast = predictor.prophet_model.predict(future)
231
- all_predictions = forecast['yhat'].values
232
-
233
  # 創建日期索引
234
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
235
  date_labels = [d.strftime('%m/%d') for d in dates]
236
 
237
  # 繪圖
238
  fig, ax = plt.subplots(figsize=(14, 7))
 
 
239
 
240
- if model_type == "LSTM":
241
- colors = ['#FF9999', '#66B2FF']
242
- labels = [f'預測{feature}' for feature in selected_features]
243
- for i, (label, color) in enumerate(zip(labels, colors)):
244
- ax.plot(date_labels, all_predictions[:, i], label=label,
245
- marker='o', color=color, linewidth=2)
246
- for j, value in enumerate(all_predictions[:, i]):
247
- ax.annotate(f'{value:.2f}', (date_labels[j], value),
248
- textcoords="offset points", xytext=(0,10),
249
- ha='center', va='bottom')
250
- elif model_type == "Prophet":
251
- ax.plot(date_labels[1:], all_predictions, label=f'預測{target_feature}',
252
- marker='o', color='#FF9999', linewidth=2)
253
- for j, value in enumerate(all_predictions):
254
- ax.annotate(f'{value:.2f}', (date_labels[j+1], value),
255
- textcoords="offset points", xytext=(0,10),
256
- ha='center', va='bottom')
257
 
258
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
259
  ax.set_xlabel('日期', labelpad=10)
260
  ax.set_ylabel('股價', labelpad=10)
261
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
262
  ax.grid(True, linestyle='--', alpha=0.7)
263
- plt.tight_layout()
264
 
 
265
  return gr.update(value=fig), "預測成功"
266
 
267
  except Exception as e:
@@ -303,33 +296,31 @@ with gr.Blocks() as demo:
303
  label="選擇要用於預測的特徵",
304
  value=['Open', 'Close']
305
  )
306
- model_type_radio = gr.Radio(
307
- choices=["LSTM", "Prophet"],
308
- label="選擇模型類型",
309
- value="LSTM"
310
- )
311
  predict_button = gr.Button("開始預測", variant="primary")
312
  status_output = gr.Textbox(label="狀態", interactive=False)
313
- with gr.Column():
314
- stock_plot = gr.Plot(label="股價預測圖")
315
-
 
316
  # 事件綁定
317
  category_dropdown.change(
318
  update_category,
319
  inputs=[category_dropdown],
320
  outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
321
  )
 
322
  stock_dropdown.change(
323
  update_stock,
324
  inputs=[category_dropdown, stock_dropdown],
325
  outputs=[stock_item_dropdown, stock_plot, status_output]
326
  )
 
327
  predict_button.click(
328
  predict_stock,
329
- inputs=[category_dropdown, stock_dropdown, stock_item_dropdown,
330
- period_dropdown, features_checkbox, model_type_radio],
331
  outputs=[stock_plot, status_output]
332
  )
333
-
 
334
  if __name__ == "__main__":
335
- demo.launch(share=False)
 
16
  import yfinance as yf
17
  import logging
18
  from datetime import datetime, timedelta
 
19
 
20
  # 設置日誌
21
  logging.basicConfig(level=logging.INFO,
 
26
  try:
27
  url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
28
  response_font = requests.get(url_font)
29
+
30
  with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
31
  tmp_file.write(response_font.content)
32
  tmp_file_path = tmp_file.name
33
+
34
  fm.fontManager.addfont(tmp_file_path)
35
  mpl.rc('font', family='Taipei Sans TC Beta')
36
  except Exception as e:
 
52
  url = "https://tw.stock.yahoo.com/class/"
53
  response = requests.get(url, headers=headers, timeout=10)
54
  response.raise_for_status()
55
+
56
  soup = BeautifulSoup(response.text, 'html.parser')
57
  main_categories = soup.find_all('div', class_='C($c-link-text)')
58
+
59
  data = []
60
  for category in main_categories:
61
  main_category_name = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
62
  if main_category_name:
63
  main_category_name = main_category_name.text.strip()
64
  sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
65
+
66
  for sub_category in sub_categories:
67
  data.append({
68
  '台股': main_category_name,
69
  '類股': sub_category.text.strip(),
70
  '網址': "https://tw.stock.yahoo.com" + sub_category['href']
71
  })
72
+
73
  category_dict = {}
74
  for item in data:
75
  if item['台股'] not in category_dict:
76
  category_dict[item['台股']] = []
77
  category_dict[item['台股']].append({'類股': item['類股'], '網址': item['網址']})
78
+
79
  return category_dict
80
  except Exception as e:
81
  logging.error(f"獲取股票類別失敗: {str(e)}")
 
84
  # 股票預測模型類別
85
  class StockPredictor:
86
  def __init__(self):
87
+ self.model = None
 
88
  self.scaler = MinMaxScaler()
89
+
90
  def prepare_data(self, df, selected_features):
91
  scaled_data = self.scaler.fit_transform(df[selected_features])
92
+
93
  X, y = [], []
94
  for i in range(len(scaled_data) - 1):
95
  X.append(scaled_data[i])
96
  y.append(scaled_data[i+1])
97
+
98
  return np.array(X).reshape(-1, 1, len(selected_features)), np.array(y)
99
 
100
  def build_model(self, input_shape):
 
110
 
111
  def train(self, df, selected_features):
112
  X, y = self.prepare_data(df, selected_features)
113
+ self.model = self.build_model((1, X.shape[2]))
114
+ history = self.model.fit(
115
  X, y,
116
  epochs=50,
117
  batch_size=32,
 
123
  def predict(self, last_data, n_days):
124
  predictions = []
125
  current_data = last_data.copy()
126
+
127
  for _ in range(n_days):
128
+ next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)
129
  predictions.append(next_day[0])
130
+
131
  current_data = current_data.flatten()
132
  current_data[:len(next_day[0])] = next_day[0]
133
  current_data = current_data.reshape(1, -1)
134
+
135
  return np.array(predictions)
 
 
 
 
136
 
137
+ # Gradio界面函數
138
  def update_stocks(category):
139
  if not category or category not in category_dict:
140
  return []
 
144
  try:
145
  response = requests.get(url, headers=headers, timeout=10)
146
  response.raise_for_status()
147
+
148
  soup = BeautifulSoup(response.text, 'html.parser')
149
  stock_items = soup.find_all('li', class_='List(n)')
150
+
151
  stocks_dict = {}
152
  for item in stock_items:
153
  stock_name = item.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
 
157
  display_code = full_code.split('.')[0]
158
  display_name = f"{stock_name.text.strip()}{display_code}"
159
  stocks_dict[display_name] = full_code
160
+
161
  return stocks_dict
162
  except Exception as e:
163
  logging.error(f"獲取股票項目失敗: {str(e)}")
 
179
  stock_plot: gr.update(value=None),
180
  status_output: gr.update(value="")
181
  }
182
+
183
  url = next((item['網址'] for item in category_dict.get(category, [])
184
  if item['類股'] == stock), None)
185
+
186
  if url:
187
  stock_items = get_stock_items(url)
188
  return {
 
196
  status_output: gr.update(value="")
197
  }
198
 
199
+ def predict_stock(category, stock, stock_item, period, selected_features):
200
  if not all([category, stock, stock_item]):
201
  return gr.update(value=None), "請選擇產業類別、類股和股票"
202
+
203
  try:
204
  url = next((item['網址'] for item in category_dict.get(category, [])
205
+ if item['類股'] == stock), None)
206
  if not url:
207
  return gr.update(value=None), "無法獲取類股網址"
208
+
209
  stock_items = get_stock_items(url)
210
  stock_code = stock_items.get(stock_item, "")
211
+
212
  if not stock_code:
213
  return gr.update(value=None), "無法獲取股票代碼"
214
 
215
+ # 下載股票數據,根據用戶選擇的時間範圍
216
  df = yf.download(stock_code, period=period)
217
  if df.empty:
218
  raise ValueError("無法獲取股票數據")
219
 
220
+ # 預測
221
  predictor = StockPredictor()
222
+ predictor.train(df, selected_features)
223
+
224
+ last_data = predictor.scaler.transform(df[selected_features].iloc[-1:].values)
225
+ predictions = predictor.predict(last_data[0], 5)
226
+
227
+ # 反轉預測結果
228
+ last_original = df[selected_features].iloc[-1].values
229
+ predictions_original = predictor.scaler.inverse_transform(
230
+ np.vstack([last_data, predictions])
231
+ )
232
+ all_predictions = np.vstack([last_original, predictions_original[1:]])
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  # 創建日期索引
235
  dates = [datetime.now() + timedelta(days=i) for i in range(6)]
236
  date_labels = [d.strftime('%m/%d') for d in dates]
237
 
238
  # 繪圖
239
  fig, ax = plt.subplots(figsize=(14, 7))
240
+ colors = ['#FF9999', '#66B2FF']
241
+ labels = [f'預測{feature}' for feature in selected_features]
242
 
243
+ for i, (label, color) in enumerate(zip(labels, colors)):
244
+ ax.plot(date_labels, all_predictions[:, i], label=label,
245
+ marker='o', color=color, linewidth=2)
246
+ for j, value in enumerate(all_predictions[:, i]):
247
+ ax.annotate(f'{value:.2f}', (date_labels[j], value),
248
+ textcoords="offset points", xytext=(0,10),
249
+ ha='center', va='bottom')
 
 
 
 
 
 
 
 
 
 
250
 
251
  ax.set_title(f'{stock_item} 股價預測 (未來5天)', pad=20, fontsize=14)
252
  ax.set_xlabel('日期', labelpad=10)
253
  ax.set_ylabel('股價', labelpad=10)
254
  ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
255
  ax.grid(True, linestyle='--', alpha=0.7)
 
256
 
257
+ plt.tight_layout()
258
  return gr.update(value=fig), "預測成功"
259
 
260
  except Exception as e:
 
296
  label="選擇要用於預測的特徵",
297
  value=['Open', 'Close']
298
  )
 
 
 
 
 
299
  predict_button = gr.Button("開始預測", variant="primary")
300
  status_output = gr.Textbox(label="狀態", interactive=False)
301
+
302
+ with gr.Row():
303
+ stock_plot = gr.Plot(label="股價預測圖")
304
+
305
  # 事件綁定
306
  category_dropdown.change(
307
  update_category,
308
  inputs=[category_dropdown],
309
  outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
310
  )
311
+
312
  stock_dropdown.change(
313
  update_stock,
314
  inputs=[category_dropdown, stock_dropdown],
315
  outputs=[stock_item_dropdown, stock_plot, status_output]
316
  )
317
+
318
  predict_button.click(
319
  predict_stock,
320
+ inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
 
321
  outputs=[stock_plot, status_output]
322
  )
323
+
324
+ # 啟動應用
325
  if __name__ == "__main__":
326
+ demo.launch(share=False)