Spaces:
Runtime error
Runtime error
| import os | |
| import math | |
| import tempfile | |
| import warnings | |
| import streamlit as st | |
| import pandas as pd | |
| import torch | |
| import plotly.express as px | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from transformers import ( | |
| EarlyStoppingCallback, | |
| Trainer, | |
| TrainingArguments, | |
| set_seed, | |
| ) | |
| from transformers.integrations import INTEGRATION_TO_CALLBACK | |
| from tsfm_public import ( | |
| TimeSeriesPreprocessor, | |
| TrackingCallback, | |
| count_parameters, | |
| get_datasets, | |
| ) | |
| from tsfm_public.toolkit.get_model import get_model | |
| from tsfm_public.toolkit.lr_finder import optimal_lr_finder | |
| from tsfm_public.toolkit.visualization import plot_predictions | |
| # For M4 Hourly Example | |
| from tsfm_public.models.tinytimemixer import TinyTimeMixerForPrediction | |
# --------------------------
# Global configuration
# Silence every warning category so library chatter does not leak into the UI.
warnings.simplefilter("ignore")

# Fixed seed shared by transformers' seeding helper and the TrainingArguments below.
SEED = 42
set_seed(SEED)

# Default checkpoint and window sizes offered by the zero-shot mode.
TTM_MODEL_PATH = "ibm-granite/granite-timeseries-ttm-r2"
DEFAULT_CONTEXT_LENGTH = 512
DEFAULT_PREDICTION_LENGTH = 96

# Root folder for generated plots and finetuning artifacts; created at import time.
OUT_DIR = "dashboard_outputs"
os.makedirs(OUT_DIR, exist_ok=True)
| # -------------------------- | |
| # Helper: Interactive Plot | |
def interactive_plot(actual, forecast, title="Forecast vs Actual"):
    """Build an interactive Plotly line chart comparing two series.

    Both inputs are reduced to one-dimensional arrays before plotting because
    the DataFrame constructor used here requires 1-D, equal-length columns.
    Multi-dimensional inputs keep their first axis and the first entry along
    every other axis; if the two series differ in length, both are trimmed to
    the shorter one.

    Args:
        actual: Observed values (sequence, NumPy array, or torch tensor).
        forecast: Predicted values (sequence, NumPy array, or torch tensor).
        title: Chart title.

    Returns:
        The figure returned by ``px.line`` with an "Actual" and a "Forecast"
        trace plotted against a 0-based "Time" step index.
    """
    import numpy as np  # function-scope import: numpy is not imported at file level

    def _as_series(values):
        """Reduce *values* to a 1-D NumPy array along its first axis."""
        if torch.is_tensor(values):
            # Detach and move to host memory before converting to NumPy.
            values = values.detach().cpu().numpy()
        series = np.asarray(values)
        if series.size == 0:
            return series.reshape(0)
        if series.ndim == 0:
            return series.reshape(1)
        if series.ndim > 1:
            # Keep the first axis; take the first entry of all trailing axes.
            series = series.reshape(series.shape[0], -1)[:, 0]
        return series

    actual_series = _as_series(actual)
    forecast_series = _as_series(forecast)
    horizon = min(len(actual_series), len(forecast_series))

    df = pd.DataFrame(
        {
            "Time": range(horizon),
            "Actual": actual_series[:horizon],
            "Forecast": forecast_series[:horizon],
        }
    )
    return px.line(df, x="Time", y=["Actual", "Forecast"], title=title)
| # -------------------------- | |
| # Mode 1: Zero-shot Evaluation | |
def run_zero_shot_forecasting(
    data,
    context_length,
    prediction_length,
    batch_size,
    selected_target_columns,
    selected_conditional_columns,
    rolling_forecast_extension,
    selected_forecast_index,
):
    """Run zero-shot evaluation of the pre-trained TTM model and render results.

    The data is split 70/10/20 by row index, preprocessed with standard
    scaling, evaluated on the test split, and one forecast window is plotted
    both interactively and via ``plot_predictions``.

    Args:
        data: DataFrame containing a "date" column plus value columns.
        context_length: Number of history steps given to the model.
        prediction_length: Number of steps to forecast.
        batch_size: Per-device evaluation batch size.
        selected_target_columns: Target column names; when empty, every
            column except "date" is used.
        selected_conditional_columns: Column names passed to the preprocessor
            as control columns.
        rolling_forecast_extension: Extra rolling steps requested; only
            reported in the UI (no rolling logic is implemented).
        selected_forecast_index: Index of the test window to plot; values
            outside the available windows are clamped with a warning.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.write("### Preparing Data for Forecasting")
    timestamp_column = "date"
    id_columns = []  # Modify if needed.

    # Widgets may hand back non-int numerics; the library expects plain ints.
    context_length = int(context_length)
    prediction_length = int(prediction_length)
    batch_size = int(batch_size)

    # Use selected target columns; default to all columns (except "date") if not provided.
    if not selected_target_columns:
        target_columns = [col for col in data.columns if col != timestamp_column]
    else:
        target_columns = selected_target_columns

    # Incorporate exogenous/control columns (treat a missing selection as none).
    conditional_columns = list(selected_conditional_columns or [])

    # Define column specifiers (if your preprocessor supports static columns, add here)
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": conditional_columns,
    }

    n = len(data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.8)],
        "test": [int(n * 0.8), n],
    }

    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=prediction_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    # Only the test split is needed for zero-shot evaluation.
    _, _, dset_test = get_datasets(tsp, data, split_config)
    st.write("Data split into train, validation, and test sets.")

    if len(dset_test) == 0:
        st.error(
            "The test split contains no forecast windows; provide more data or "
            "reduce the context/prediction length."
        )
        return

    st.write("### Loading the Pre-trained TTM Model")
    model = get_model(
        TTM_MODEL_PATH,
        context_length=context_length,
        prediction_length=prediction_length,
    )

    # The Trainer needs an output directory; use one that is removed afterwards
    # instead of leaking a new temp folder on every run.
    with tempfile.TemporaryDirectory() as temp_dir:
        training_args = TrainingArguments(
            output_dir=temp_dir,
            per_device_eval_batch_size=batch_size,
            seed=SEED,
            report_to="none",
        )
        trainer = Trainer(model=model, args=training_args)

        st.write("### Running Zero-shot Evaluation")
        st.info("Evaluating on the test set...")
        eval_output = trainer.evaluate(dset_test)
        st.write("**Zero-shot Evaluation Metrics:**")
        st.json(eval_output)

        st.write("### Generating Forecast Predictions")
        predictions_dict = trainer.predict(dset_test)

    try:
        predictions_np = predictions_dict.predictions[0]
    except Exception as e:
        st.error("Error extracting predictions: " + str(e))
        return
    st.write("Predictions shape:", predictions_np.shape)

    if rolling_forecast_extension > 0:
        st.write(
            f"### Rolling Forecast Extension: {rolling_forecast_extension} extra steps"
        )
        st.info("Rolling forecast logic can be implemented here.")

    # The UI slider is bounded by the row count of the raw data, which can be
    # larger than the number of test windows; clamp instead of crashing.
    num_windows = len(predictions_np)
    idx = int(selected_forecast_index)
    if not 0 <= idx < num_windows:
        clamped = max(0, min(idx, num_windows - 1))
        st.warning(
            f"Forecast index {idx} is outside the {num_windows} available test "
            f"windows; showing index {clamped} instead."
        )
        idx = clamped

    def _first_channel(values):
        """Return channel 0 of a (time, channel) window as a host-side array."""
        if torch.is_tensor(values):
            values = values.detach().cpu().numpy()
        if getattr(values, "ndim", 1) > 1:
            values = values[:, 0]
        return values

    forecast_window = predictions_np[idx]

    # Ground truth for the selected window.
    # NOTE(review): tsfm forecasting datasets are expected to expose the future
    # window under "future_values" -- confirm against the installed version.
    # "target" is kept as a secondary key for other dataset implementations.
    sample = dset_test[idx]
    actual = None
    if isinstance(sample, dict):
        for key in ("future_values", "target"):
            if key in sample:
                actual = sample[key]
                break
    else:
        try:
            actual = sample[0]
        except (IndexError, KeyError, TypeError):
            actual = None
    if actual is None:
        # Make the fallback visible rather than silently plotting the forecast twice.
        st.warning(
            "Ground-truth values are not available for this window; "
            "the 'Actual' trace mirrors the forecast."
        )
        actual = forecast_window

    # Interactive plot for the selected forecast index (first channel only).
    fig = interactive_plot(
        _first_channel(actual),
        _first_channel(forecast_window),
        title=f"Forecast vs Actual for index {idx}",
    )
    st.plotly_chart(fig)

    # Static plots (generated via plot_predictions)
    plot_dir = os.path.join(OUT_DIR, "zero_shot_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="test_zeroshot",
            indices=[idx],
            channel=0,
        )
    except Exception as e:
        st.error("Error during static plotting: " + str(e))
        return
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
| # -------------------------- | |
| # Mode 2: Channel-Mix Finetuning Example | |
def run_channel_mix_finetuning():
    """Finetune a channel-mix TTM decoder on the bike sharing data and report results.

    Steps: download the hourly bike sharing CSV, build train/valid/test
    datasets (50/25/25 by row index), load the TTM-R1 checkpoint in
    "mix_channel" decoder mode, freeze the backbone, train with AdamW and a
    one-cycle schedule under early stopping, evaluate on the test split, and
    show the plots produced by ``plot_predictions``.

    Returns:
        None. All output is written to the Streamlit page; checkpoints, logs
        and plots are written under ``OUT_DIR/bike_sharing``.
    """
    import inspect  # function-scope import: only needed for the compatibility check below

    st.write("## Channel-Mix Finetuning Example (Bike Sharing Data)")

    # Load bike sharing dataset
    target_dataset = "bike_sharing"
    DATA_ROOT_PATH = (
        "https://raw.githubusercontent.com/blobibob/bike-sharing-dataset/main/hour.csv"
    )
    timestamp_column = "dteday"
    id_columns = []
    try:
        data = pd.read_csv(DATA_ROOT_PATH, parse_dates=[timestamp_column])
    except Exception as e:
        st.error("Error loading bike sharing dataset: " + str(e))
        return
    data[timestamp_column] = pd.to_datetime(data[timestamp_column])
    # Adjust timestamps (to add hourly information): the n-th row of each
    # calendar day is shifted by n hours.
    data[timestamp_column] = data[timestamp_column] + pd.to_timedelta(
        data.groupby(data[timestamp_column].dt.date).cumcount(), unit="h"
    )
    st.write("### Bike Sharing Data Preview")
    st.dataframe(data.head())

    # Define columns: targets and conditional (exogenous) channels
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": ["casual", "registered", "cnt"],
        "conditional_columns": [
            "season",
            "yr",
            "mnth",
            "holiday",
            "weekday",
            "workingday",
            "weathersit",
            "temp",
            "atemp",
            "hum",
            "windspeed",
        ],
    }

    n = len(data)
    split_config = {
        "train": [0, int(n * 0.5)],
        "valid": [int(n * 0.5), int(n * 0.75)],
        "test": [int(n * 0.75), n],
    }
    context_length = 512
    forecast_length = 96

    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    train_dataset, valid_dataset, test_dataset = get_datasets(tsp, data, split_config)
    st.write("Data split completed.")

    # For channel-mix finetuning, we use TTM-R1 (as per provided script)
    TTM_MODEL_PATH_CM = "ibm-granite/granite-timeseries-ttm-r1"
    finetune_forecast_model = get_model(
        TTM_MODEL_PATH_CM,
        context_length=context_length,
        prediction_length=forecast_length,
        num_input_channels=tsp.num_input_channels,
        decoder_mode="mix_channel",
        prediction_channel_indices=tsp.prediction_channel_indices,
    )

    st.write(
        "Number of params before freezing backbone:",
        count_parameters(finetune_forecast_model),
    )
    for param in finetune_forecast_model.backbone.parameters():
        param.requires_grad = False
    st.write(
        "Number of params after freezing backbone:",
        count_parameters(finetune_forecast_model),
    )

    num_epochs = 50
    batch_size = 64
    learning_rate = 0.001

    # Only hand the optimizer the parameters that are still trainable; the
    # frozen backbone never receives gradients.
    trainable_params = [
        p for p in finetune_forecast_model.parameters() if p.requires_grad
    ]
    optimizer = AdamW(trainable_params, lr=learning_rate)
    scheduler = OneCycleLR(
        optimizer,
        learning_rate,
        epochs=num_epochs,
        steps_per_epoch=math.ceil(len(train_dataset) / batch_size),
    )

    out_dir = os.path.join(OUT_DIR, target_dataset)
    os.makedirs(out_dir, exist_ok=True)

    # transformers renamed `evaluation_strategy` to `eval_strategy`; pick the
    # keyword the installed version actually accepts so both old and new
    # releases work.
    accepted_args = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_key = (
        "eval_strategy" if "eval_strategy" in accepted_args else "evaluation_strategy"
    )

    finetune_args = TrainingArguments(
        output_dir=os.path.join(out_dir, "output"),
        overwrite_output_dir=True,
        learning_rate=learning_rate,
        num_train_epochs=num_epochs,
        do_eval=True,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        # Do not request more loader workers than the host has CPUs.
        dataloader_num_workers=min(8, os.cpu_count() or 1),
        report_to="none",
        save_strategy="epoch",
        logging_strategy="epoch",
        save_total_limit=1,
        logging_dir=os.path.join(out_dir, "logs"),
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        seed=SEED,
        **{eval_strategy_key: "epoch"},
    )

    early_stopping_callback = EarlyStoppingCallback(
        early_stopping_patience=10,
        early_stopping_threshold=1e-5,
    )
    tracking_callback = TrackingCallback()

    finetune_trainer = Trainer(
        model=finetune_forecast_model,
        args=finetune_args,
        train_dataset=train_dataset,
        eval_dataset=valid_dataset,
        callbacks=[early_stopping_callback, tracking_callback],
        optimizers=(optimizer, scheduler),
    )
    # Drop the codecarbon integration if this transformers build registers one;
    # a missing registry entry must not abort the run.
    codecarbon_callback = INTEGRATION_TO_CALLBACK.get("codecarbon")
    if codecarbon_callback is not None:
        finetune_trainer.remove_callback(codecarbon_callback)

    st.write("Starting channel-mix finetuning...")
    finetune_trainer.train()

    st.write("Evaluating finetuned model on test set...")
    eval_output = finetune_trainer.evaluate(test_dataset)
    st.write("Few-shot (channel-mix) evaluation metrics:")
    st.json(eval_output)

    # Plot predictions
    plot_dir = os.path.join(out_dir, "channel_mix_plots")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=finetune_trainer.model,
            dset=test_dataset,
            plot_dir=plot_dir,
            plot_prefix="test_channel_mix",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting channel mix predictions: " + str(e))
        return
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)
| # -------------------------- | |
| # Mode 3: M4 Hourly Example | |
def run_m4_hourly_example():
    """Run a simplified zero-shot evaluation labelled as the M4 hourly example.

    Downloads a CSV with a "date" column, treats every other column as a
    target, splits it 70/15/15 by row index, evaluates the TTM-V1 checkpoint
    with a 48-step prediction filter on the test split, and shows the plots
    produced by ``plot_predictions``.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.write("## M4 Hourly Example")
    st.info("This example reproduces a simplified version of the M4 hourly evaluation.")

    # For demonstration, we attempt to load an M4 hourly dataset from a URL.
    # (In practice, you would need to download and prepare the dataset.)
    # NOTE(review): the placeholder below points at an ETTh1 file, not at M4
    # data -- replace it with a real M4 hourly source before relying on results.
    M4_DATASET_URL = "https://raw.githubusercontent.com/IBM/TSFM-public/main/tsfm_public/notebooks/ETTh1.csv"  # Placeholder URL
    try:
        m4_data = pd.read_csv(M4_DATASET_URL, parse_dates=["date"])
    except Exception as e:
        st.error("Could not load M4 hourly dataset: " + str(e))
        return
    st.write("### M4 Hourly Data Preview")
    st.dataframe(m4_data.head())

    context_length = 512
    forecast_length = 48  # M4 hourly forecast horizon
    timestamp_column = "date"
    id_columns = []
    target_columns = [col for col in m4_data.columns if col != timestamp_column]

    n = len(m4_data)
    split_config = {
        "train": [0, int(n * 0.7)],
        "valid": [int(n * 0.7), int(n * 0.85)],
        "test": [int(n * 0.85), n],
    }
    column_specifiers = {
        "timestamp_column": timestamp_column,
        "id_columns": id_columns,
        "target_columns": target_columns,
        "control_columns": [],
    }
    tsp = TimeSeriesPreprocessor(
        **column_specifiers,
        context_length=context_length,
        prediction_length=forecast_length,
        scaling=True,
        encode_categorical=False,
        scaler_type="standard",
    )
    # Only the test split is evaluated in this example.
    _, _, dset_test = get_datasets(tsp, m4_data, split_config)
    st.write("Data split completed.")

    # Load model from Hugging Face TTM Model Repository (TTM-V1 for M4)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TinyTimeMixerForPrediction.from_pretrained(
        "ibm-granite/granite-timeseries-ttm-v1",
        revision="main",
        prediction_filter_length=forecast_length,
    ).to(device)

    st.write("Running zero-shot evaluation on M4 hourly data...")
    # The Trainer requires an output directory; use one that is cleaned up
    # when evaluation finishes rather than leaking a temp folder per run.
    with tempfile.TemporaryDirectory() as temp_dir:
        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir=temp_dir,
                per_device_eval_batch_size=64,
                report_to="none",
            ),
        )
        eval_output = trainer.evaluate(dset_test)
    st.write("Zero-shot evaluation metrics on M4 hourly:")
    st.json(eval_output)

    plot_dir = os.path.join(OUT_DIR, "m4_hourly", "zero_shot")
    os.makedirs(plot_dir, exist_ok=True)
    try:
        plot_predictions(
            model=trainer.model,
            dset=dset_test,
            plot_dir=plot_dir,
            plot_prefix="m4_zero_shot",
            indices=[0],
            channel=0,
        )
    except Exception as e:
        st.error("Error plotting M4 zero-shot predictions: " + str(e))
        return
    for file in os.listdir(plot_dir):
        if file.endswith(".png"):
            st.image(os.path.join(plot_dir, file), caption=file)

    st.info("Fine-tuning on M4 hourly data can be added similarly.")
| # -------------------------- | |
| # Main UI | |
def main():
    """Render the dashboard and dispatch to the selected evaluation mode.

    The zero-shot mode collects a dataset (default ETTh1 download or an
    uploaded CSV with a "date" column) plus forecasting options and calls
    ``run_zero_shot_forecasting``; the other two modes run their example
    function when the corresponding button is pressed.
    """
    st.title("Interactive Time-Series Forecasting Dashboard")
    st.markdown(
        """
This dashboard lets you run advanced forecasting experiments using the Granite-TimeSeries-TTM model.
Select one of the modes below:
- **Zero-shot Evaluation**
- **Channel-Mix Finetuning Example**
- **M4 Hourly Example**
"""
    )
    mode = st.selectbox(
        "Select Evaluation Mode",
        options=[
            "Zero-shot Evaluation",
            "Channel-Mix Finetuning Example",
            "M4 Hourly Example",
        ],
    )

    if mode == "Zero-shot Evaluation":
        # Allow user to choose dataset source
        dataset_source = st.radio(
            "Dataset Source", options=["Default (ETTh1)", "Upload CSV"]
        )
        if dataset_source == "Default (ETTh1)":
            DATASET_PATH = "https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/ETTh1.csv"
            try:
                data = pd.read_csv(DATASET_PATH, parse_dates=["date"])
            except Exception as e:
                # Surface the cause, matching the other loaders in this file.
                st.error("Error loading default dataset: " + str(e))
                return
            st.write("### Default Dataset Preview")
            st.dataframe(data.head())
            selected_target_columns = [
                "HUFL",
                "HULL",
                "MUFL",
                "MULL",
                "LUFL",
                "LULL",
                "OT",
            ]
        else:
            uploaded_file = st.file_uploader("Upload your CSV file", type=["csv"])
            if not uploaded_file:
                st.info("Awaiting CSV file upload.")
                return
            try:
                data = pd.read_csv(uploaded_file, parse_dates=["date"])
            except Exception as e:
                # Typical causes: no "date" column or a malformed CSV.
                st.error("Error reading uploaded CSV: " + str(e))
                return
            st.write("### Uploaded Data Preview")
            st.dataframe(data.head())
            available_columns = [col for col in data.columns if col != "date"]
            selected_target_columns = st.multiselect(
                "Select Target Column(s)",
                options=available_columns,
                default=available_columns,
            )

        # The plotting slider below needs min_value < max_value.
        if len(data) < 2:
            st.error("The dataset must contain at least two rows.")
            return

        # Advanced options
        available_exog = [
            col
            for col in data.columns
            if col not in (["date"] + selected_target_columns)
        ]
        selected_conditional_columns = st.multiselect(
            "Select Exogenous/Control Columns", options=available_exog, default=[]
        )
        rolling_extension = st.number_input(
            "Rolling Forecast Extension (Extra Steps)", value=0, min_value=0, step=1
        )
        forecast_index = st.slider(
            "Select Forecast Index for Plotting",
            min_value=0,
            max_value=len(data) - 1,
            value=0,
        )
        # Lengths and batch size must be positive; reject 0/negative in the widget.
        context_length = st.number_input(
            "Context Length", value=DEFAULT_CONTEXT_LENGTH, min_value=1, step=64
        )
        prediction_length = st.number_input(
            "Prediction Length", value=DEFAULT_PREDICTION_LENGTH, min_value=1, step=1
        )
        batch_size = st.number_input("Batch Size", value=64, min_value=1, step=1)

        if st.button("Run Zero-shot Evaluation"):
            with st.spinner("Running zero-shot evaluation..."):
                run_zero_shot_forecasting(
                    data,
                    context_length,
                    prediction_length,
                    batch_size,
                    selected_target_columns,
                    selected_conditional_columns,
                    rolling_extension,
                    forecast_index,
                )
    elif mode == "Channel-Mix Finetuning Example":
        if st.button("Run Channel-Mix Finetuning Example"):
            with st.spinner("Running channel-mix finetuning..."):
                run_channel_mix_finetuning()
    elif mode == "M4 Hourly Example":
        if st.button("Run M4 Hourly Example"):
            with st.spinner("Running M4 hourly example..."):
                run_m4_hourly_example()


if __name__ == "__main__":
    main()