|
import warnings |
|
warnings.filterwarnings("ignore") |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import pandas as pd |
|
import yfinance as yf |
|
import matplotlib.pyplot as plt |
|
|
|
from pandas.tseries.frequencies import to_offset |
|
from gluonts.dataset.common import ListDataset |
|
|
|
|
|
|
|
|
|
try: |
|
from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module |
|
except Exception as e: |
|
raise ImportError( |
|
"Moirai 2.0 not found in your Uni2TS install.\n" |
|
"Ensure requirements.txt includes:\n" |
|
" git+https://github.com/SalesforceAIResearch/uni2ts.git\n" |
|
f"Original error: {e}" |
|
) |
|
|
|
MODEL_ID = "Salesforce/moirai-2.0-R-small"
DEFAULT_CONTEXT = 1680

# Process-wide cache for the pretrained weights; populated on first use by load_module().
_MODULE = None


def load_module():
    """Return the pretrained Moirai 2.0 module, loading it lazily.

    The first call downloads/instantiates ``Moirai2Module.from_pretrained(MODEL_ID)``
    and stores it in the module-level ``_MODULE``; subsequent calls return the
    cached instance without reloading.
    """
    global _MODULE
    if _MODULE is not None:
        return _MODULE
    _MODULE = Moirai2Module.from_pretrained(MODEL_ID)
    return _MODULE
|
|
|
|
|
|
|
|
|
def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex:
    """Return the ``horizon`` timestamps that follow ``last_idx`` at ``freq`` spacing.

    The first timestamp is ``last_idx`` advanced by one ``freq`` offset (so a
    Friday with ``"B"`` steps to the next Monday); the remaining ones are
    generated by ``pd.date_range`` with the same frequency.
    """
    first_step = last_idx + to_offset(freq)
    return pd.date_range(start=first_step, periods=horizon, freq=freq)
|
|
|
def _point_forecast(forecast) -> np.ndarray:
    """Extract a point forecast from a GluonTS-style forecast object.

    Preference order: a non-callable ``mean`` attribute, the 0.5 quantile,
    the average of ``samples`` over axis 0, and finally the object itself.
    """
    mean = getattr(forecast, "mean", None)
    # ndarray-like objects also expose ``.mean`` -- but as a method, not values.
    if mean is not None and not callable(mean):
        return np.asarray(mean)
    if hasattr(forecast, "quantile"):
        return np.asarray(forecast.quantile(0.5))
    if hasattr(forecast, "samples"):
        return np.asarray(forecast.samples).mean(axis=0)
    return np.asarray(forecast)


def _plot_forecast(y: pd.Series, pred: pd.Series, title: str):
    """Draw history and forecast on a new matplotlib Figure and return it."""
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(y.index, y.values, label="history")
    ax.plot(pred.index, pred.values, label="forecast")
    ax.set_title(title)
    ax.set_xlabel("Time")
    ax.set_ylabel("Value")
    ax.legend()
    fig.tight_layout()
    # plt.close only drops pyplot's global reference so figures do not pile up
    # across requests; the returned Figure object can still be rendered/saved.
    plt.close(fig)
    return fig


def _run_forecast_on_series(
    y: pd.Series,
    freq: str,
    horizon: int,
    context_hint: int,
    title: str,
):
    """Forecast ``horizon`` steps past the end of ``y`` with Moirai 2.0.

    Args:
        y: History indexed by timestamps at frequency ``freq``.
        freq: Pandas frequency string used for the dataset and future index.
        horizon: Number of steps to predict (coerced to ``int``).
        context_hint: Requested context length; a falsy value falls back to
            ``DEFAULT_CONTEXT``. It is clipped to ``[32, len(y)]``.
        title: Plot title.

    Returns:
        ``(fig, out_df)``: a matplotlib Figure with history and forecast, and a
        DataFrame with ``date`` and ``prediction`` columns.

    Raises:
        gr.Error: If ``y`` has fewer than 50 points or ``horizon`` < 1.
    """
    if len(y) < 50:
        raise gr.Error("Need at least 50 points to forecast.")
    # UI sliders may hand over floats; slicing and date_range need a real int.
    horizon = int(horizon)
    if horizon < 1:
        raise gr.Error("Forecast horizon must be at least 1.")

    ctx = int(np.clip(context_hint or DEFAULT_CONTEXT, 32, len(y)))
    target = y.values[-ctx:].astype(np.float32)
    start_idx = y.index[-ctx]
    ds = ListDataset([{"start": start_idx, "target": target}], freq=freq)

    model = Moirai2Forecast(
        module=load_module(),
        prediction_length=horizon,
        context_length=ctx,
        target_dim=1,
        feat_dynamic_real_dim=0,
        past_feat_dynamic_real_dim=0,
    )
    predictor = model.create_predictor(batch_size=32)
    # A single series goes in, so exactly one forecast comes out.
    forecast = next(iter(predictor.predict(ds)))

    yhat = _point_forecast(forecast).ravel()[:horizon]
    future_idx = _future_index(y.index[-1], freq, horizon)
    pred = pd.Series(yhat, index=future_idx, name="prediction")

    fig = _plot_forecast(y, pred, title)
    out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values})
    return fig, out_df
|
|
|
|
|
|
|
|
|
def fetch_series(ticker: str, years: int) -> pd.Series:
    """Fetch daily close prices and align to business-day frequency.

    Downloads ``years`` years of daily bars for ``ticker`` via yfinance with
    ``auto_adjust=True``, takes the close column, and returns it on a
    business-day index whose gaps are forward-filled from the previous close.

    Args:
        ticker: Symbol to download; also used as the returned Series' name.
        years: Length of history to request, in whole years.

    Returns:
        Close prices on a tz-naive, midnight-normalized business-day index.

    Raises:
        gr.Error: If yfinance returns no data, no usable close column, or only
            missing values.
    """
    data = yf.download(
        ticker,
        period=f"{int(years)}y",  # int() so a float like 5.0 cannot yield "5.0y"
        interval="1d",
        auto_adjust=True,
        progress=False,
        threads=True,
    )
    if data is None or data.empty:
        raise gr.Error(f"No price data found for '{ticker}'.")

    col = "Close" if "Close" in data.columns else ("Adj Close" if "Adj Close" in data.columns else None)
    if col is None:
        raise gr.Error(f"Unexpected columns from yfinance: {list(data.columns)}")

    # yfinance may return (field, ticker) MultiIndex columns even for one symbol.
    if isinstance(data.columns, pd.MultiIndex):
        closes = data[col]
        s = closes[ticker] if ticker in closes.columns else closes.iloc[:, 0]
    else:
        s = data[col]

    y = s.copy()
    y.name = ticker
    # Drop timezone and time-of-day so labels match bdate_range(), which
    # produces midnight timestamps; otherwise reindex would find no matches.
    y.index = pd.DatetimeIndex(y.index).tz_localize(None).normalize()
    # reindex() rejects duplicate labels; keep the latest row for each date.
    y = y[~y.index.duplicated(keep="last")].sort_index()

    bidx = pd.bdate_range(y.index.min(), y.index.max())
    y = y.reindex(bidx).ffill()

    if y.isna().all():
        raise gr.Error(f"Only missing values for '{ticker}'.")
    return y
|
|
|
def forecast_ticker(ticker: str, horizon: int, lookback_years: int, context_hint: int):
    """Forecast a ticker's business-day closes; handler for the "By Ticker" tab.

    The symbol is trimmed and upper-cased, inputs are validated, the history
    is fetched with ``fetch_series`` and forecast at frequency ``"B"``.

    Returns:
        The ``(figure, table)`` pair produced by ``_run_forecast_on_series``.

    Raises:
        gr.Error: On an empty symbol or a horizon below 1.
    """
    symbol = (ticker or "").strip().upper()
    if not symbol:
        raise gr.Error("Please enter a ticker symbol (e.g., AAPL).")
    if horizon < 1:
        raise gr.Error("Forecast horizon must be at least 1.")
    history = fetch_series(symbol, lookback_years)
    chart_title = f"{symbol} β forecast (Moirai 2.0 R-small)"
    return _run_forecast_on_series(history, "B", horizon, context_hint, chart_title)
|
|
|
|
|
|
|
|
|
def _read_csv_columns(file_path: str) -> pd.DataFrame:
    """Read a CSV into a DataFrame, tolerating common non-comma delimiters.

    The file is parsed with pandas' default comma dialect first. If that
    raises, the python engine sniffs the delimiter instead. If it succeeds
    but yields a single column whose header contains ``;``, a tab, or ``|``,
    the file almost certainly uses that delimiter and is re-read with it
    (the comma parse does not fail on such files, so the sniffing fallback
    alone never catches them).

    Args:
        file_path: Path of the CSV file on disk.

    Returns:
        The parsed DataFrame.
    """
    try:
        df = pd.read_csv(file_path)
    except Exception:
        # Comma dialect failed outright: let the python engine sniff the delimiter.
        return pd.read_csv(file_path, sep=None, engine="python")

    if df.shape[1] == 1:
        header = str(df.columns[0])
        for delim in (";", "\t", "|"):
            if delim in header:
                try:
                    return pd.read_csv(file_path, sep=delim)
                except ValueError:
                    # Keep the comma parse; downstream column checks report the problem.
                    break
    return df
|
|
|
def _coerce_numeric_series(s: pd.Series) -> pd.Series:
    """Parse ``s`` as numbers and discard every entry that is not a valid number.

    Unparseable entries are coerced to NaN and removed together with any
    pre-existing NaNs. Surviving rows keep their original index labels, so the
    result can still be aligned with other columns of the same frame. The
    returned values are float32.
    """
    numeric = pd.to_numeric(s, errors="coerce")
    valid = numeric[numeric.notna()]
    return valid.astype(np.float32)
|
|
|
def _uploaded_file_path(file):
    """Return the filesystem path behind a Gradio upload value.

    Accepts objects exposing ``.name`` or ``.path``, or a plain path string.
    """
    path = (
        getattr(file, "name", None)
        or getattr(file, "path", None)
        or (file if isinstance(file, str) else None)
    )
    if path is None:
        raise gr.Error("Could not read the uploaded file path.")
    return path


def _select_value_series(df: pd.DataFrame, value_col: str, date_col: str) -> pd.Series:
    """Return the numeric target values, with unparseable rows dropped.

    Uses ``value_col`` when given. Otherwise it takes the first numeric-dtype
    column other than ``date_col``, falling back to the first non-date column.
    """
    if value_col:
        if value_col not in df.columns:
            raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}")
        return _coerce_numeric_series(df[value_col])

    # Never auto-select the date column: a numeric one (e.g. an integer year)
    # would otherwise be forecast as the target.
    candidates = [c for c in df.columns if c != date_col]
    numeric_cols = [c for c in candidates if pd.api.types.is_numeric_dtype(df[c])]
    if numeric_cols:
        return _coerce_numeric_series(df[numeric_cols[0]])
    # No numeric dtype (e.g. values mixed with "N/A" strings): coerce the first
    # non-date column rather than whatever happens to be column 0.
    if candidates:
        return _coerce_numeric_series(df[candidates[0]])
    return _coerce_numeric_series(df.iloc[:, 0])


def _resolve_freq(index: pd.DatetimeIndex, freq_choice_norm: str) -> str:
    """Pick the frequency: explicit choice, else inferred, else B/D by weekday share."""
    if freq_choice_norm and freq_choice_norm != "AUTO":
        return freq_choice_norm
    inferred = pd.infer_freq(index)
    if inferred:
        return inferred
    # Irregular timestamps: treat as business-daily if nearly all fall Mon-Fri.
    weekday_ratio = (index.dayofweek < 5).mean()
    return "B" if weekday_ratio > 0.95 else "D"


def _dated_series(df: pd.DataFrame, vals: pd.Series, value_col: str, date_col: str, freq_choice_norm: str):
    """Index ``vals`` by ``df[date_col]`` and regularize it to a single frequency.

    Returns:
        ``(y, freq)``: the date-sorted, de-duplicated series placed on a regular
        ``freq`` grid (new slots forward-filled) and the frequency used.
    """
    if date_col not in df.columns:
        raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}")
    dt = pd.to_datetime(df[date_col], errors="coerce")
    # vals keeps df's row labels but lacks unparseable rows; realign explicitly.
    aligned = vals.reindex(dt.index)
    mask = dt.notna() & aligned.notna()
    dates = pd.DatetimeIndex(dt[mask]).tz_localize(None)
    values = aligned[mask]
    if len(values) < 10:
        raise gr.Error("Too few valid rows after parsing date/value columns.")

    y = pd.Series(values.values, index=dates, name=value_col or "value")
    # A stable sort keeps file order among equal dates, so keep="last" below
    # deterministically retains the last row written for each date.
    y = y.sort_index(kind="stable")
    y = y[~y.index.duplicated(keep="last")]
    if len(y) < 10:
        # Also guards pd.infer_freq, which raises on fewer than 3 dates.
        raise gr.Error("Too few distinct dates after removing duplicates (need at least 10).")

    freq = _resolve_freq(y.index, freq_choice_norm)
    # asfreq builds a regular grid from first to last date; new slots take the
    # previous observation.
    return y.asfreq(freq, method="ffill"), freq


def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str):
    """Build a regularly spaced series from an uploaded CSV.

    - If ``date_col`` is given: dates are parsed, rows with an unparseable date
      or value are dropped, duplicate dates keep the last row in file order,
      and the series is aligned to ``freq_choice`` (or an inferred frequency
      when it is blank/"auto").
    - If no ``date_col``: a synthetic index starting 2000-01-01 is created at
      ``freq_choice``, defaulting to ``"D"`` when blank/"auto".

    When ``value_col`` is blank, the first numeric column other than the date
    column is used.

    Returns:
        ``(y, freq)``: the series with a DatetimeIndex and its frequency string.

    Raises:
        gr.Error: On a missing/empty file, unknown columns, too few numeric
            values or distinct dates, or an all-NaN result.
    """
    if file is None:
        raise gr.Error("Please upload a CSV file.")
    df = _read_csv_columns(_uploaded_file_path(file))
    if df.empty:
        raise gr.Error("Uploaded file is empty.")

    value_col = (value_col or "").strip()
    date_col = (date_col or "").strip()
    freq_choice_norm = (freq_choice or "").strip().upper()

    vals = _select_value_series(df, value_col, date_col)
    if len(vals) < 10:
        raise gr.Error("Not enough numeric values after parsing (need at least 10).")

    if date_col:
        y, freq = _dated_series(df, vals, value_col, date_col, freq_choice_norm)
    else:
        freq = "D" if (not freq_choice_norm or freq_choice_norm == "AUTO") else freq_choice_norm
        idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq)
        y = pd.Series(vals.values, index=idx, name=value_col or "value")

    if y.isna().all():
        raise gr.Error("Series is all-NaN after processing.")
    return y, freq
|
|
|
def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int):
    """Forecast a series uploaded as CSV; handler for the "Upload CSV" tab.

    The series and its frequency come from ``build_series_from_csv``; the
    forecast itself is delegated to ``_run_forecast_on_series``.

    Returns:
        The ``(figure, table)`` pair for the Gradio outputs.
    """
    series, freq = build_series_from_csv(file, value_col, date_col, freq_choice)
    heading = f"Uploaded series β forecast (freq={freq})"
    return _run_forecast_on_series(series, freq, horizon, context_hint, heading)
|
|
|
|
|
|
|
|
|
# Gradio UI: two tabs ("By Ticker", "Upload CSV") that both end in
# _run_forecast_on_series and render a plot plus a forecast table.
with gr.Blocks(title="Moirai 2.0 β Time Series Forecast (Research)") as demo:
    gr.Markdown(
        """
        # Moirai 2.0 β Time Series Forecast (Research)
        Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ticker *or* a generic CSV time series.

        > **Important**: Research/educational use only. Not investment advice. Model license: **CC-BY-NC-4.0 (non-commercial)**.
        """
    )

    # Tab 1: download daily closes for a symbol and forecast them.
    with gr.Tab("By Ticker"):
        with gr.Row():
            ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA")
            horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)")
        with gr.Row():
            lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)")
            # Same value as DEFAULT_CONTEXT; the forecaster clips it to [32, len(series)].
            ctx_t = gr.Slider(64, 5000, value=1680, step=16, label="Context length")
        run_t = gr.Button("Run forecast", variant="primary")
        plot_t = gr.Plot(label="History + Forecast")
        table_t = gr.Dataframe(label="Forecast table", interactive=False)
        # Inputs are passed positionally: forecast_ticker(ticker, horizon, lookback_years, context_hint).
        run_t.click(forecast_ticker, inputs=[ticker, horizon_t, lookback, ctx_t], outputs=[plot_t, table_t])

    # Tab 2: forecast a user-uploaded CSV series.
    with gr.Tab("Upload CSV"):
        gr.Markdown(
            "Upload a CSV with either (1) a **date/time column** and a **value column**, "
            "or (2) just a numeric value column (then choose a frequency, or leave **auto** to default to **D**)."
        )
        with gr.Row():
            file = gr.File(label="CSV file", file_types=[".csv"])
        with gr.Row():
            date_col = gr.Textbox(label="Date/time column (optional)", placeholder="e.g., date, timestamp")
            value_col = gr.Textbox(label="Value column (optional β auto-detects first numeric)", placeholder="e.g., value, close")
        with gr.Row():
            # build_series_from_csv upper-cases this choice, so "auto" is compared as "AUTO".
            freq_choice = gr.Dropdown(
                label="Frequency",
                value="auto",
                choices=["auto", "B", "D", "H", "W", "M", "MS"],
                info="If no date column, 'auto' defaults to D (daily)."
            )
        with gr.Row():
            horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)")
            ctx_u = gr.Slider(32, 5000, value=512, step=16, label="Context length")
        run_u = gr.Button("Run forecast on CSV", variant="primary")
        plot_u = gr.Plot(label="History + Forecast (CSV)")
        table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False)
        # value_col precedes date_col here (unlike the on-screen order) to match
        # forecast_csv(file, value_col, date_col, freq_choice, horizon, context_hint).
        run_u.click(
            forecast_csv,
            inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u],
            outputs=[plot_u, table_u],
        )


if __name__ == "__main__":
    demo.launch()
|
|
|
|