Tonic commited on
Commit
7f41c98
·
verified ·
1 Parent(s): 45d4dac

initial stock predictions testing

Browse files
Files changed (3) hide show
  1. README.md +62 -0
  2. app.py +213 -0
  3. requirements.txt +119 -0
README.md CHANGED
@@ -12,3 +12,65 @@ short_description: Use Amazon Chronos To Predict Stock Prices
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ---
13
 
14
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
15
+
16
+
17
+ # Stock Price Prediction with Amazon Chronos
18
+
19
+ A neural network application that uses Amazon's Chronos model for time series forecasting to predict stock prices.
20
+
21
+ ## Features
22
+
23
+ - Real-time stock price predictions using Amazon Chronos
24
+ - Interactive visualization of predictions with confidence intervals
25
+ - Support for multiple timeframes (daily, hourly, 15-minute)
26
+ - User-friendly Gradio interface
27
+ - Free stock data using yfinance API
28
+
29
+ ## Hugging Face Spaces Deployment
30
+
31
+ This application is configured to run on Hugging Face Spaces. To deploy:
32
+
33
+ 1. Create a new Space on Hugging Face
34
+ 2. Choose "Docker" as the SDK
35
+ 3. Upload all the files to your Space
36
+
37
+ ## Local Development
38
+
39
+ To run locally:
40
+
41
+ ```bash
42
+ # Create and activate virtual environment
43
+ python -m venv .venv
44
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
45
+
46
+ # Install dependencies
47
+ pip install -r requirements.txt
48
+
49
+ # Run the application
50
+ python app.py
51
+ ```
52
+
53
+ ## Model Details
54
+
55
+ The application uses Amazon's Chronos model for time series forecasting. The model is configured to:
56
+
57
+ - Make predictions for stock prices
58
+ - Calculate confidence intervals
59
+ - Generate interactive visualizations
60
+ - Support multiple prediction horizons
61
+
62
+ ## Usage
63
+
64
+ 1. Enter a stock symbol (e.g., AAPL, GOOGL, MSFT)
65
+ 2. Select the desired timeframe (1d, 1h, 15m)
66
+ 3. Choose the number of days to predict (1-30)
67
+ 4. Click "Make Prediction" to generate forecasts
68
+
69
+ The application will display:
70
+ - A plot showing historical prices and predictions
71
+ - Confidence intervals for the predictions
72
+ - A separate plot showing prediction uncertainty
73
+
74
+ ## License
75
+
76
+ This project is licensed under the MIT License - see the LICENSE file for details.
app.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ import numpy as np
4
+ from datetime import datetime, timedelta
5
+ import yfinance as yf
6
+ import torch
7
+ from chronos import BaseChronosPipeline
8
+ import plotly.graph_objects as go
9
+ from plotly.subplots import make_subplots
10
+
11
+ # Initialize Chronos pipeline
12
+ pipeline = None
13
+
14
+ def load_pipeline():
15
+ """Load the Chronos model with CPU configuration"""
16
+ global pipeline
17
+ if pipeline is None:
18
+ pipeline = BaseChronosPipeline.from_pretrained(
19
+ "amazon/chronos-bolt-base",
20
+ device_map="cpu", # Force CPU usage
21
+ torch_dtype=torch.float32 # Use float32 for CPU
22
+ )
23
+ pipeline.model = pipeline.model.eval()
24
+ return pipeline
25
+
26
+ def get_historical_data(symbol: str, timeframe: str = "1d") -> np.ndarray:
27
+ """
28
+ Fetch historical data using yfinance.
29
+
30
+ Args:
31
+ symbol (str): The stock symbol (e.g., 'AAPL')
32
+ timeframe (str): The timeframe for data ('1d', '1h', '15m')
33
+
34
+ Returns:
35
+ np.ndarray: Array of historical prices for Chronos model
36
+ """
37
+ try:
38
+ # Map timeframe to yfinance interval
39
+ tf_map = {
40
+ "1d": "1d",
41
+ "1h": "1h",
42
+ "15m": "15m"
43
+ }
44
+ interval = tf_map.get(timeframe, "1d")
45
+
46
+ # Calculate date range
47
+ end_date = datetime.now()
48
+ if timeframe == "1d":
49
+ start_date = end_date - timedelta(days=365) # 1 year of daily data
50
+ elif timeframe == "1h":
51
+ start_date = end_date - timedelta(days=30) # 30 days of hourly data
52
+ else: # 15m
53
+ start_date = end_date - timedelta(days=7) # 7 days of 15-min data
54
+
55
+ # Fetch data using yfinance
56
+ ticker = yf.Ticker(symbol)
57
+ df = ticker.history(start=start_date, end=end_date, interval=interval)
58
+
59
+ # Calculate returns
60
+ df['returns'] = df['Close'].pct_change()
61
+
62
+ # Drop NaN values
63
+ df = df.dropna()
64
+
65
+ # Normalize the data
66
+ returns = df['returns'].values
67
+ normalized_returns = (returns - returns.mean()) / returns.std()
68
+
69
+ # Convert to the format expected by Chronos
70
+ return normalized_returns.reshape(-1, 1)
71
+
72
+ except Exception as e:
73
+ raise Exception(f"Error fetching historical data for {symbol}: {str(e)}")
74
+
75
+ def make_prediction(symbol: str, timeframe: str = "1d", prediction_days: int = 5):
76
+ """
77
+ Make prediction using Chronos model.
78
+
79
+ Args:
80
+ symbol (str): Stock symbol
81
+ timeframe (str): Data timeframe
82
+ prediction_days (int): Number of days to predict
83
+
84
+ Returns:
85
+ dict: Prediction results and visualization
86
+ """
87
+ try:
88
+ # Load pipeline
89
+ pipe = load_pipeline()
90
+
91
+ # Get historical data
92
+ historical_data = get_historical_data(symbol, timeframe)
93
+
94
+ # Convert to tensor
95
+ context = torch.tensor(historical_data, dtype=torch.float32)
96
+
97
+ # Make prediction
98
+ with torch.inference_mode():
99
+ prediction = pipe.predict(
100
+ context=context,
101
+ prediction_length=prediction_days,
102
+ num_samples=100
103
+ ).detach().cpu().numpy()
104
+
105
+ # Get actual historical prices for plotting
106
+ ticker = yf.Ticker(symbol)
107
+ hist_data = ticker.history(period="1mo")
108
+
109
+ # Create prediction dates
110
+ last_date = hist_data.index[-1]
111
+ pred_dates = pd.date_range(start=last_date + timedelta(days=1), periods=prediction_days)
112
+
113
+ # Calculate prediction statistics
114
+ mean_pred = prediction.mean(axis=0)
115
+ std_pred = prediction.std(axis=0)
116
+
117
+ # Create visualization
118
+ fig = make_subplots(rows=2, cols=1, shared_xaxes=True,
119
+ vertical_spacing=0.03, subplot_titles=('Price Prediction', 'Confidence Interval'))
120
+
121
+ # Add historical data
122
+ fig.add_trace(
123
+ go.Scatter(x=hist_data.index, y=hist_data['Close'], name='Historical Price',
124
+ line=dict(color='blue')),
125
+ row=1, col=1
126
+ )
127
+
128
+ # Add prediction mean
129
+ fig.add_trace(
130
+ go.Scatter(x=pred_dates, y=mean_pred, name='Predicted Price',
131
+ line=dict(color='red')),
132
+ row=1, col=1
133
+ )
134
+
135
+ # Add confidence intervals
136
+ fig.add_trace(
137
+ go.Scatter(x=pred_dates, y=mean_pred + 1.96 * std_pred,
138
+ fill=None, mode='lines', line_color='rgba(255,0,0,0.2)',
139
+ name='Upper Bound'),
140
+ row=1, col=1
141
+ )
142
+ fig.add_trace(
143
+ go.Scatter(x=pred_dates, y=mean_pred - 1.96 * std_pred,
144
+ fill='tonexty', mode='lines', line_color='rgba(255,0,0,0.2)',
145
+ name='Lower Bound'),
146
+ row=1, col=1
147
+ )
148
+
149
+ # Add confidence interval plot
150
+ fig.add_trace(
151
+ go.Scatter(x=pred_dates, y=std_pred, name='Prediction Uncertainty',
152
+ line=dict(color='green')),
153
+ row=2, col=1
154
+ )
155
+
156
+ # Update layout
157
+ fig.update_layout(
158
+ title=f'{symbol} Price Prediction',
159
+ xaxis_title='Date',
160
+ yaxis_title='Price',
161
+ height=800,
162
+ showlegend=True
163
+ )
164
+
165
+ return {
166
+ "symbol": symbol,
167
+ "prediction": mean_pred.tolist(),
168
+ "confidence": std_pred.tolist(),
169
+ "dates": pred_dates.strftime('%Y-%m-%d').tolist(),
170
+ "plot": fig
171
+ }
172
+
173
+ except Exception as e:
174
+ raise Exception(f"Prediction error: {str(e)}")
175
+
176
+ # Create Gradio interface
177
+ def create_interface():
178
+ with gr.Blocks(title="Stock Price Prediction with Amazon Chronos") as demo:
179
+ gr.Markdown("# Stock Price Prediction with Amazon Chronos")
180
+ gr.Markdown("Enter a stock symbol and select prediction parameters to get price forecasts.")
181
+
182
+ with gr.Row():
183
+ with gr.Column():
184
+ symbol = gr.Textbox(label="Stock Symbol (e.g., AAPL)", value="AAPL")
185
+ timeframe = gr.Dropdown(
186
+ choices=["1d", "1h", "15m"],
187
+ label="Timeframe",
188
+ value="1d"
189
+ )
190
+ prediction_days = gr.Slider(
191
+ minimum=1,
192
+ maximum=30,
193
+ value=5,
194
+ step=1,
195
+ label="Days to Predict"
196
+ )
197
+ predict_btn = gr.Button("Make Prediction")
198
+
199
+ with gr.Column():
200
+ plot = gr.Plot(label="Prediction Visualization")
201
+ results = gr.JSON(label="Prediction Results")
202
+
203
+ predict_btn.click(
204
+ fn=make_prediction,
205
+ inputs=[symbol, timeframe, prediction_days],
206
+ outputs=[results, plot]
207
+ )
208
+
209
+ return demo
210
+
211
+ if __name__ == "__main__":
212
+ demo = create_interface()
213
+ demo.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # first pip install numpy scipy scikit-learn all seperately
2
+
3
+
4
+ pandas_datareader
5
+ # numpy
6
+ #--find-links https://download.pytorch.org/whl/torch_stable.html
7
+ #torch==1.11.0+cu113
8
+ # torch==1.13.0 breaks due to _posthooks
9
+ torch>=2.1.2
10
+ #torchvision==0.10.0+cu111
11
+
12
+ pandas>=2.0.0
13
+ # scipy
14
+ loguru>=0.7.0
15
+ matplotlib
16
+ neuralforecast
17
+ retry>=0.9.2
18
+ hyperopt
19
+ #neuralprophet
20
+ alpaca-trade-api>=3.0.0
21
+
22
+ SQLAlchemy
23
+ websocket-client
24
+ py
25
+ future
26
+ Pillow
27
+ ipython
28
+ pbr
29
+ setuptools
30
+ six
31
+ wheel
32
+ pip
33
+ tqdm
34
+ optuna
35
+ # scikit-learn
36
+ filelock
37
+ transformers>=4.36.0
38
+ click
39
+ requests>=2.31.0
40
+ joblib
41
+ aiohttp
42
+ tensorboard
43
+ msgpack
44
+ urllib3
45
+ rsa
46
+ pyasn1
47
+ attrs
48
+ wcwidth
49
+ cmd2
50
+ pyperclip
51
+ fsspec
52
+ packaging
53
+ parso
54
+ jedi
55
+ lxml
56
+ Mako
57
+ MarkupSafe
58
+ pytz>=2023.3
59
+ toml
60
+ idna
61
+ multidict
62
+ cliff
63
+ stevedore
64
+ autopage
65
+ prettytable
66
+ certifi
67
+ patsy
68
+ regex
69
+ cachetools>=5.3.0
70
+ python-dateutil
71
+ cmaes
72
+ alembic
73
+ colorlog
74
+ traitlets
75
+ decorator
76
+ backcall
77
+ pickleshare
78
+ pluggy
79
+ iniconfig
80
+ yarl
81
+ chardet
82
+ threadpoolctl
83
+ greenlet
84
+ Markdown
85
+ oauthlib
86
+ Werkzeug
87
+ fonttools
88
+ pyparsing
89
+ websockets
90
+ statsmodels
91
+ # cycler
92
+ # kiwisolver
93
+ # sacremoses
94
+ tokenizers #==0.15.2 --only-binary=:all:
95
+ # torchmetrics
96
+ # zipp
97
+ # typer
98
+ #pytorch-forecasting
99
+ # pytorch-forecasting
100
+ #pytorch-lightning
101
+ alpaca-py>=0.8.0
102
+ fastapi
103
+ gunicorn
104
+ uvicorn
105
+ # git+https://github.com/amazon-science/chronos-forecasting.git
106
+ chronos-forecasting
107
+ scikit-learn
108
+
109
+ python-binance
110
+ typer
111
+ diskcache
112
+ anthropic
113
+ gradio>=4.0.0
114
+ spaces>=0.1.0
115
+ numpy>=1.24.0
116
+ torch>=2.0.0
117
+ yfinance>=0.2.0
118
+ plotly>=5.0.0
119
+ chronos>=0.1.0