Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import warnings
|
| 3 |
+
warnings.filterwarnings("ignore")
|
| 4 |
+
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import numpy as np
|
| 8 |
+
import yfinance as yf
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from gluonts.dataset.common import ListDataset
|
| 13 |
+
|
| 14 |
+
# Moirai 2.0 via Uni2TS (per Salesforce's example)
|
| 15 |
+
# https://www.salesforce.com/blog/moirai-2-0/
|
| 16 |
+
from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module # type: ignore
|
| 17 |
+
|
| 18 |
+
MODEL_ID = "Salesforce/moirai-2.0-R-small"
|
| 19 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
+
|
| 21 |
+
# Load the Moirai 2.0 module once at startup
|
| 22 |
+
_MODULE = None
|
| 23 |
+
def load_module():
|
| 24 |
+
global _MODULE
|
| 25 |
+
if _MODULE is None:
|
| 26 |
+
_MODULE = Moirai2Module.from_pretrained(MODEL_ID)
|
| 27 |
+
return _MODULE
|
| 28 |
+
|
| 29 |
+
def fetch_series(ticker: str, years: int) -> pd.Series:
|
| 30 |
+
"""Fetch daily close price and align to business-day frequency."""
|
| 31 |
+
data = yf.download(
|
| 32 |
+
ticker,
|
| 33 |
+
period=f"{years}y",
|
| 34 |
+
interval="1d",
|
| 35 |
+
auto_adjust=True,
|
| 36 |
+
progress=False,
|
| 37 |
+
threads=True,
|
| 38 |
+
)
|
| 39 |
+
if data is None or data.empty:
|
| 40 |
+
raise gr.Error(f"No price data found for '{ticker}'.")
|
| 41 |
+
# Prefer 'Close' after auto_adjust; fall back to 'Adj Close' if needed
|
| 42 |
+
col = "Close" if "Close" in data.columns else "Adj Close"
|
| 43 |
+
y = data[col].rename(ticker)
|
| 44 |
+
y.index = pd.DatetimeIndex(y.index).tz_localize(None)
|
| 45 |
+
|
| 46 |
+
# Business-day index, forward-fill market holidays
|
| 47 |
+
bidx = pd.bdate_range(y.index.min(), y.index.max())
|
| 48 |
+
y = y.reindex(bidx).ffill()
|
| 49 |
+
if y.isna().all():
|
| 50 |
+
raise gr.Error(f"Only missing values for '{ticker}'.")
|
| 51 |
+
return y
|
| 52 |
+
|
| 53 |
+
def forecast_ticker(ticker: str,
|
| 54 |
+
horizon: int,
|
| 55 |
+
lookback_years: int,
|
| 56 |
+
context_hint: int):
|
| 57 |
+
ticker = (ticker or "").strip().upper()
|
| 58 |
+
if not ticker:
|
| 59 |
+
raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
|
| 60 |
+
if horizon < 1:
|
| 61 |
+
raise gr.Error("Forecast horizon must be at least 1.")
|
| 62 |
+
|
| 63 |
+
# 1) Get history
|
| 64 |
+
y = fetch_series(ticker, lookback_years)
|
| 65 |
+
if len(y) < 50:
|
| 66 |
+
raise gr.Error("Not enough history to forecast (need at least 50 points).")
|
| 67 |
+
|
| 68 |
+
# 2) Build dataset for GluonTS-style predictor
|
| 69 |
+
# Use business-day freq ('B'); pick a context <= history length.
|
| 70 |
+
default_ctx = 1680 # from Moirai 2.0 examples
|
| 71 |
+
ctx = int(np.clip(context_hint or default_ctx, 32, len(y)))
|
| 72 |
+
target = y.values[-ctx:]
|
| 73 |
+
start_idx = y.index[-ctx]
|
| 74 |
+
|
| 75 |
+
ds = ListDataset([{"start": start_idx, "target": target}], freq="B")
|
| 76 |
+
|
| 77 |
+
# 3) Create forecast wrapper and predictor
|
| 78 |
+
module = load_module()
|
| 79 |
+
model = Moirai2Forecast(
|
| 80 |
+
module=module,
|
| 81 |
+
prediction_length=int(horizon),
|
| 82 |
+
context_length=ctx,
|
| 83 |
+
target_dim=1,
|
| 84 |
+
feat_dynamic_real_dim=0,
|
| 85 |
+
past_feat_dynamic_real_dim=0,
|
| 86 |
+
)
|
| 87 |
+
predictor = model.create_predictor(batch_size=32, device=DEVICE)
|
| 88 |
+
|
| 89 |
+
# 4) Predict
|
| 90 |
+
forecast = next(iter(predictor.predict(ds)))
|
| 91 |
+
|
| 92 |
+
# 5) Extract a reasonable central estimate
|
| 93 |
+
if hasattr(forecast, "mean"):
|
| 94 |
+
yhat = np.asarray(forecast.mean)
|
| 95 |
+
elif hasattr(forecast, "quantile"):
|
| 96 |
+
# 50th percentile as point
|
| 97 |
+
yhat = np.asarray(forecast.quantile(0.5))
|
| 98 |
+
elif hasattr(forecast, "samples"):
|
| 99 |
+
yhat = np.asarray(forecast.samples).mean(axis=0)
|
| 100 |
+
else:
|
| 101 |
+
# very defensive fallback
|
| 102 |
+
yhat = np.asarray(forecast)
|
| 103 |
+
|
| 104 |
+
# Guard length (some forecast objects can be slightly longer)
|
| 105 |
+
yhat = np.asarray(yhat).ravel()[:horizon]
|
| 106 |
+
|
| 107 |
+
# 6) Assemble dates & outputs
|
| 108 |
+
# Next business days after the last historical date
|
| 109 |
+
future_idx = pd.bdate_range(y.index[-1] + pd.tseries.offsets.BDay(), periods=horizon)
|
| 110 |
+
pred = pd.Series(yhat, index=future_idx, name="predicted_close")
|
| 111 |
+
|
| 112 |
+
# 7) Plot
|
| 113 |
+
fig = plt.figure(figsize=(10, 5))
|
| 114 |
+
plt.plot(y.index, y.values, label="history")
|
| 115 |
+
plt.plot(pred.index, pred.values, label="forecast")
|
| 116 |
+
plt.title(f"{ticker} close price forecast (Moirai 2.0 R-small)")
|
| 117 |
+
plt.xlabel("Date"); plt.ylabel("Price"); plt.legend(); plt.tight_layout()
|
| 118 |
+
|
| 119 |
+
# 8) Table
|
| 120 |
+
out_df = pd.DataFrame({"date": pred.index, "predicted_close": pred.values})
|
| 121 |
+
return fig, out_df
|
| 122 |
+
|
| 123 |
+
with gr.Blocks(title="Moirai 2.0 — Stock Price Forecast (Research)") as demo:
|
| 124 |
+
gr.Markdown(
|
| 125 |
+
"""
|
| 126 |
+
# Moirai 2.0 — Stock Price Forecast (Research)
|
| 127 |
+
Enter a ticker to fetch recent daily prices and generate a short-term forecast using **Salesforce/moirai-2.0-R-small**.
|
| 128 |
+
> **Important**: For **research/educational** use only. Not investment advice. Model license is **CC-BY-NC-4.0 (non-commercial)**.
|
| 129 |
+
"""
|
| 130 |
+
)
|
| 131 |
+
with gr.Row():
|
| 132 |
+
ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
|
| 133 |
+
horizon = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (business days)")
|
| 134 |
+
with gr.Row():
|
| 135 |
+
lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
|
| 136 |
+
ctx = gr.Slider(64, 2000, value=1680, step=16, label="Context length (points)")
|
| 137 |
+
|
| 138 |
+
run = gr.Button("Run forecast", variant="primary")
|
| 139 |
+
plot = gr.Plot(label="History + Forecast")
|
| 140 |
+
table = gr.Dataframe(label="Forecast table", interactive=False)
|
| 141 |
+
|
| 142 |
+
run.click(forecast_ticker, inputs=[ticker, horizon, lookback, ctx], outputs=[plot, table])
|
| 143 |
+
|
| 144 |
+
if __name__ == "__main__":
|
| 145 |
+
demo.launch()
|