Vishwas1 commited on
Commit
a4f8e39
·
verified ·
1 Parent(s): beab7f5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import warnings
3
+ warnings.filterwarnings("ignore")
4
+
5
+ import gradio as gr
6
+ import pandas as pd
7
+ import numpy as np
8
+ import yfinance as yf
9
+ import matplotlib.pyplot as plt
10
+
11
+ import torch
12
+ from gluonts.dataset.common import ListDataset
13
+
14
+ # Moirai 2.0 via Uni2TS (per Salesforce's example)
15
+ # https://www.salesforce.com/blog/moirai-2-0/
16
+ from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module # type: ignore
17
+
18
+ MODEL_ID = "Salesforce/moirai-2.0-R-small"
19
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
20
+
21
+ # Load the Moirai 2.0 module once at startup
22
+ _MODULE = None
23
+ def load_module():
24
+ global _MODULE
25
+ if _MODULE is None:
26
+ _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
27
+ return _MODULE
28
+
29
+ def fetch_series(ticker: str, years: int) -> pd.Series:
30
+ """Fetch daily close price and align to business-day frequency."""
31
+ data = yf.download(
32
+ ticker,
33
+ period=f"{years}y",
34
+ interval="1d",
35
+ auto_adjust=True,
36
+ progress=False,
37
+ threads=True,
38
+ )
39
+ if data is None or data.empty:
40
+ raise gr.Error(f"No price data found for '{ticker}'.")
41
+ # Prefer 'Close' after auto_adjust; fall back to 'Adj Close' if needed
42
+ col = "Close" if "Close" in data.columns else "Adj Close"
43
+ y = data[col].rename(ticker)
44
+ y.index = pd.DatetimeIndex(y.index).tz_localize(None)
45
+
46
+ # Business-day index, forward-fill market holidays
47
+ bidx = pd.bdate_range(y.index.min(), y.index.max())
48
+ y = y.reindex(bidx).ffill()
49
+ if y.isna().all():
50
+ raise gr.Error(f"Only missing values for '{ticker}'.")
51
+ return y
52
+
53
+ def forecast_ticker(ticker: str,
54
+ horizon: int,
55
+ lookback_years: int,
56
+ context_hint: int):
57
+ ticker = (ticker or "").strip().upper()
58
+ if not ticker:
59
+ raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
60
+ if horizon < 1:
61
+ raise gr.Error("Forecast horizon must be at least 1.")
62
+
63
+ # 1) Get history
64
+ y = fetch_series(ticker, lookback_years)
65
+ if len(y) < 50:
66
+ raise gr.Error("Not enough history to forecast (need at least 50 points).")
67
+
68
+ # 2) Build dataset for GluonTS-style predictor
69
+ # Use business-day freq ('B'); pick a context <= history length.
70
+ default_ctx = 1680 # from Moirai 2.0 examples
71
+ ctx = int(np.clip(context_hint or default_ctx, 32, len(y)))
72
+ target = y.values[-ctx:]
73
+ start_idx = y.index[-ctx]
74
+
75
+ ds = ListDataset([{"start": start_idx, "target": target}], freq="B")
76
+
77
+ # 3) Create forecast wrapper and predictor
78
+ module = load_module()
79
+ model = Moirai2Forecast(
80
+ module=module,
81
+ prediction_length=int(horizon),
82
+ context_length=ctx,
83
+ target_dim=1,
84
+ feat_dynamic_real_dim=0,
85
+ past_feat_dynamic_real_dim=0,
86
+ )
87
+ predictor = model.create_predictor(batch_size=32, device=DEVICE)
88
+
89
+ # 4) Predict
90
+ forecast = next(iter(predictor.predict(ds)))
91
+
92
+ # 5) Extract a reasonable central estimate
93
+ if hasattr(forecast, "mean"):
94
+ yhat = np.asarray(forecast.mean)
95
+ elif hasattr(forecast, "quantile"):
96
+ # 50th percentile as point
97
+ yhat = np.asarray(forecast.quantile(0.5))
98
+ elif hasattr(forecast, "samples"):
99
+ yhat = np.asarray(forecast.samples).mean(axis=0)
100
+ else:
101
+ # very defensive fallback
102
+ yhat = np.asarray(forecast)
103
+
104
+ # Guard length (some forecast objects can be slightly longer)
105
+ yhat = np.asarray(yhat).ravel()[:horizon]
106
+
107
+ # 6) Assemble dates & outputs
108
+ # Next business days after the last historical date
109
+ future_idx = pd.bdate_range(y.index[-1] + pd.tseries.offsets.BDay(), periods=horizon)
110
+ pred = pd.Series(yhat, index=future_idx, name="predicted_close")
111
+
112
+ # 7) Plot
113
+ fig = plt.figure(figsize=(10, 5))
114
+ plt.plot(y.index, y.values, label="history")
115
+ plt.plot(pred.index, pred.values, label="forecast")
116
+ plt.title(f"{ticker} close price forecast (Moirai 2.0 R-small)")
117
+ plt.xlabel("Date"); plt.ylabel("Price"); plt.legend(); plt.tight_layout()
118
+
119
+ # 8) Table
120
+ out_df = pd.DataFrame({"date": pred.index, "predicted_close": pred.values})
121
+ return fig, out_df
122
+
123
+ with gr.Blocks(title="Moirai 2.0 — Stock Price Forecast (Research)") as demo:
124
+ gr.Markdown(
125
+ """
126
+ # Moirai 2.0 — Stock Price Forecast (Research)
127
+ Enter a ticker to fetch recent daily prices and generate a short-term forecast using **Salesforce/moirai-2.0-R-small**.
128
+ > **Important**: For **research/educational** use only. Not investment advice. Model license is **CC-BY-NC-4.0 (non-commercial)**.
129
+ """
130
+ )
131
+ with gr.Row():
132
+ ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
133
+ horizon = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (business days)")
134
+ with gr.Row():
135
+ lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
136
+ ctx = gr.Slider(64, 2000, value=1680, step=16, label="Context length (points)")
137
+
138
+ run = gr.Button("Run forecast", variant="primary")
139
+ plot = gr.Plot(label="History + Forecast")
140
+ table = gr.Dataframe(label="Forecast table", interactive=False)
141
+
142
+ run.click(forecast_ticker, inputs=[ticker, horizon, lookback, ctx], outputs=[plot, table])
143
+
144
+ if __name__ == "__main__":
145
+ demo.launch()