import warnings warnings.filterwarnings("ignore") import gradio as gr import numpy as np import pandas as pd import yfinance as yf import matplotlib.pyplot as plt from pandas.tseries.frequencies import to_offset from gluonts.dataset.common import ListDataset # --- Moirai 2.0 via Uni2TS --- # Make sure your requirements install Uni2TS from GitHub: # git+https://github.com/SalesforceAIResearch/uni2ts.git try: from uni2ts.model.moirai2 import Moirai2Forecast, Moirai2Module except Exception as e: raise ImportError( "Moirai 2.0 not found in your Uni2TS install.\n" "Ensure requirements.txt includes:\n" " git+https://github.com/SalesforceAIResearch/uni2ts.git\n" f"Original error: {e}" ) MODEL_ID = "Salesforce/moirai-2.0-R-small" DEFAULT_CONTEXT = 1680 # from Moirai examples, but we clamp to series length # ---------------------------- # Model loader (single instance) # ---------------------------- _MODULE = None def load_module(): global _MODULE if _MODULE is None: _MODULE = Moirai2Module.from_pretrained(MODEL_ID) return _MODULE # ---------------------------- # Shared forecasting core # ---------------------------- def _future_index(last_idx: pd.Timestamp, freq: str, horizon: int) -> pd.DatetimeIndex: off = to_offset(freq) start = last_idx + off return pd.date_range(start=start, periods=horizon, freq=freq) def _run_forecast_on_series( y: pd.Series, freq: str, horizon: int, context_hint: int, title: str, ): if len(y) < 50: raise gr.Error("Need at least 50 points to forecast.") ctx = int(np.clip(context_hint or DEFAULT_CONTEXT, 32, len(y))) target = y.values[-ctx:].astype(np.float32) start_idx = y.index[-ctx] ds = ListDataset([{"start": start_idx, "target": target}], freq=freq) module = load_module() model = Moirai2Forecast( module=module, prediction_length=int(horizon), context_length=ctx, target_dim=1, feat_dynamic_real_dim=0, past_feat_dynamic_real_dim=0, ) predictor = model.create_predictor(batch_size=32) # device handled internally forecast = next(iter(predictor.predict(ds))) if hasattr(forecast, "mean"): yhat = np.asarray(forecast.mean) elif hasattr(forecast, "quantile"): yhat = np.asarray(forecast.quantile(0.5)) elif hasattr(forecast, "samples"): yhat = np.asarray(forecast.samples).mean(axis=0) else: yhat = np.asarray(forecast) yhat = np.asarray(yhat).ravel()[:horizon] future_idx = _future_index(y.index[-1], freq, horizon) pred = pd.Series(yhat, index=future_idx, name="prediction") # Plot fig = plt.figure(figsize=(10, 5)) plt.plot(y.index, y.values, label="history") plt.plot(pred.index, pred.values, label="forecast") plt.title(title) plt.xlabel("Time"); plt.ylabel("Value"); plt.legend(); plt.tight_layout() out_df = pd.DataFrame({"date": pred.index, "prediction": pred.values}) return fig, out_df # ---------------------------- # Ticker helpers # ---------------------------- def fetch_series(ticker: str, years: int) -> pd.Series: """Fetch daily close prices and align to business-day frequency.""" data = yf.download( ticker, period=f"{years}y", interval="1d", auto_adjust=True, progress=False, threads=True, ) if data is None or data.empty: raise gr.Error(f"No price data found for '{ticker}'.") col = "Close" if "Close" in data.columns else ("Adj Close" if "Adj Close" in data.columns else None) if col is None: raise gr.Error(f"Unexpected columns from yfinance: {list(data.columns)}") if isinstance(data.columns, pd.MultiIndex): if ticker in data[col].columns: s = data[col][ticker] else: s = data[col].iloc[:, 0] else: s = data[col] y = s.copy() y.name = ticker y.index = pd.DatetimeIndex(y.index).tz_localize(None) # Business-day index; forward-fill holidays bidx = pd.bdate_range(y.index.min(), y.index.max()) y = y.reindex(bidx).ffill() if y.isna().all(): raise gr.Error(f"Only missing values for '{ticker}'.") return y def forecast_ticker(ticker: str, horizon: int, lookback_years: int, context_hint: int): ticker = (ticker or "").strip().upper() if not ticker: raise gr.Error("Please enter a ticker symbol (e.g., AAPL).") if horizon < 1: raise gr.Error("Forecast horizon must be at least 1.") y = fetch_series(ticker, lookback_years) return _run_forecast_on_series(y, "B", horizon, context_hint, f"{ticker} — forecast (Moirai 2.0 R-small)") # ---------------------------- # CSV helpers # ---------------------------- def _read_csv_columns(file_path: str) -> pd.DataFrame: try: df = pd.read_csv(file_path) except Exception: df = pd.read_csv(file_path, sep=None, engine="python") return df def _coerce_numeric_series(s: pd.Series) -> pd.Series: s = pd.to_numeric(s, errors="coerce") return s.dropna().astype(np.float32) def build_series_from_csv(file, value_col: str, date_col: str, freq_choice: str): """ Returns (series y with DateTimeIndex, freq string). - If date_col is provided: parse dates and infer/align frequency. - If NO date_col: create a synthetic date index using freq_choice (default to 'D' if auto/blank). """ if file is None: raise gr.Error("Please upload a CSV file.") # Gradio file object handling (v4/v5) path = getattr(file, "name", None) or getattr(file, "path", None) or (file if isinstance(file, str) else None) if path is None: raise gr.Error("Could not read the uploaded file path.") df = _read_csv_columns(path) if df.empty: raise gr.Error("Uploaded file is empty.") # Value column selection value_col = (value_col or "").strip() if value_col: if value_col not in df.columns: raise gr.Error(f"Value column '{value_col}' not found. Available: {list(df.columns)}") vals = _coerce_numeric_series(df[value_col]) else: numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])] if numeric_cols: vals = _coerce_numeric_series(df[numeric_cols[0]]) else: vals = _coerce_numeric_series(df.iloc[:, 0]) if vals.empty or len(vals) < 10: raise gr.Error("Not enough numeric values after parsing (need at least 10).") date_col = (date_col or "").strip() freq_choice_norm = (freq_choice or "").strip().upper() if date_col: if date_col not in df.columns: raise gr.Error(f"Date column '{date_col}' not found. Available: {list(df.columns)}") dt = pd.to_datetime(df[date_col], errors="coerce") mask = dt.notna() & vals.notna() dt = pd.DatetimeIndex(dt[mask]).tz_localize(None) vals = vals[mask] if len(vals) < 10: raise gr.Error("Too few valid rows after parsing date/value columns.") # Sort & dedupe index BEFORE inferring/aligning freq order = np.argsort(dt.values) dt = dt[order] vals = vals.iloc[order].reset_index(drop=True) y = pd.Series(vals.values, index=dt, name=value_col or "value").copy() y = y[~y.index.duplicated(keep="last")].sort_index() # Choose frequency if freq_choice_norm and freq_choice_norm != "AUTO": freq = freq_choice_norm else: inferred = pd.infer_freq(y.index) if inferred: freq = inferred else: weekday_ratio = (y.index.dayofweek < 5).mean() freq = "B" if weekday_ratio > 0.95 else "D" # Align to chosen frequency y = y.asfreq(freq, method="ffill") else: # No date column: build synthetic index freq = "D" if (not freq_choice_norm or freq_choice_norm == "AUTO") else freq_choice_norm idx = pd.date_range(start="2000-01-01", periods=len(vals), freq=freq) y = pd.Series(vals.values, index=idx, name=value_col or "value").copy() if y.isna().all(): raise gr.Error("Series is all-NaN after processing.") return y, freq def forecast_csv(file, value_col: str, date_col: str, freq_choice: str, horizon: int, context_hint: int): y, freq = build_series_from_csv(file, value_col, date_col, freq_choice) return _run_forecast_on_series(y, freq, horizon, context_hint, f"Uploaded series — forecast (freq={freq})") # ---------------------------- # UI # ---------------------------- with gr.Blocks(title="Moirai 2.0 — Time Series Forecast (Research)") as demo: gr.Markdown( """ # Moirai 2.0 — Time Series Forecast (Research) Use **Salesforce/moirai-2.0-R-small** (via Uni2TS) to forecast either a stock ticker *or* a generic CSV time series. > **Important**: Research/educational use only. Not investment advice. Model license: **CC-BY-NC-4.0 (non-commercial)**. """ ) with gr.Tab("By Ticker"): with gr.Row(): ticker = gr.Textbox(label="Ticker", value="AAPL", placeholder="e.g., AAPL, MSFT, TSLA") horizon_t = gr.Slider(5, 120, value=30, step=1, label="Forecast horizon (steps)") with gr.Row(): lookback = gr.Slider(1, 10, value=5, step=1, label="Lookback window (years of history)") ctx_t = gr.Slider(64, 5000, value=1680, step=16, label="Context length") run_t = gr.Button("Run forecast", variant="primary") plot_t = gr.Plot(label="History + Forecast") table_t = gr.Dataframe(label="Forecast table", interactive=False) run_t.click(forecast_ticker, inputs=[ticker, horizon_t, lookback, ctx_t], outputs=[plot_t, table_t]) with gr.Tab("Upload CSV"): gr.Markdown( "Upload a CSV with either (1) a **date/time column** and a **value column**, " "or (2) just a numeric value column (then choose a frequency, or leave **auto** to default to **D**)." ) with gr.Row(): file = gr.File(label="CSV file", file_types=[".csv"]) with gr.Row(): date_col = gr.Textbox(label="Date/time column (optional)", placeholder="e.g., date, timestamp") value_col = gr.Textbox(label="Value column (optional — auto-detects first numeric)", placeholder="e.g., value, close") with gr.Row(): freq_choice = gr.Dropdown( label="Frequency", value="auto", choices=["auto", "B", "D", "H", "W", "M", "MS"], info="If no date column, 'auto' defaults to D (daily)." ) with gr.Row(): horizon_u = gr.Slider(1, 500, value=60, step=1, label="Forecast horizon (steps)") ctx_u = gr.Slider(32, 5000, value=512, step=16, label="Context length") run_u = gr.Button("Run forecast on CSV", variant="primary") plot_u = gr.Plot(label="History + Forecast (CSV)") table_u = gr.Dataframe(label="Forecast table (CSV)", interactive=False) run_u.click( forecast_csv, inputs=[file, value_col, date_col, freq_choice, horizon_u, ctx_u], outputs=[plot_u, table_u], ) if __name__ == "__main__": demo.launch()