import pickle
import random

import numpy as np
import torch
from torch.utils.data import Dataset

from config import Config


class QlibDataset(Dataset):
    """
    A PyTorch Dataset for handling Qlib financial time series data.

    At construction time the dataset loads a pickled mapping of
    ``symbol -> DataFrame``, derives calendar features for every symbol that
    is long enough, and pre-computes every valid ``(symbol, start_index)``
    sliding-window position. During training/validation, windows are drawn
    at random from that pool.

    Args:
        data_type (str): The type of dataset to load, either 'train' or 'val'.

    Raises:
        ValueError: If `data_type` is not 'train' or 'val'.

    NOTE(review): sampling ignores the index passed to ``__getitem__`` and
    relies on a per-instance RNG. If this dataset is copied into several
    DataLoader worker processes, each copy presumably starts from the same
    RNG state and may draw the same sequence of windows -- confirm how the
    training script seeds workers/ranks.
    """

    def __init__(self, data_type: str = 'train'):
        self.config = Config()
        if data_type not in ('train', 'val'):
            raise ValueError("data_type must be 'train' or 'val'")
        self.data_type = data_type

        # Use a dedicated random number generator for sampling to avoid
        # interfering with other random processes (e.g., in model
        # initialization).
        self.py_rng = random.Random(self.config.seed)

        # Set paths and number of samples based on the data type.
        if data_type == 'train':
            self.data_path = f"{self.config.dataset_path}/train_data.pkl"
            self.n_samples = self.config.n_train_iter
        else:
            self.data_path = f"{self.config.dataset_path}/val_data.pkl"
            self.n_samples = self.config.n_val_iter

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # Only point `dataset_path` at files produced by a trusted pipeline.
        with open(self.data_path, 'rb') as f:
            self.data = pickle.load(f)

        # Number of consecutive rows in one sample.
        self.window = self.config.lookback_window + self.config.predict_window + 1

        self.symbols = list(self.data.keys())
        self.feature_list = self.config.feature_list
        self.time_feature_list = self.config.time_feature_list

        # Pre-compute all possible (symbol, start_index) pairs.
        self.indices = []
        print(f"[{data_type.upper()}] Pre-computing sample indices...")
        for symbol in self.symbols:
            # Assumes reset_index() exposes a 'datetime' column -- the
            # calendar features below are derived from it.
            df = self.data[symbol].reset_index()
            series_len = len(df)
            num_samples = series_len - self.window + 1

            # Series shorter than one window contribute no samples; their
            # frames are left untouched in `self.data`.
            if num_samples <= 0:
                continue

            # Generate time features and store them directly in the dataframe.
            df['minute'] = df['datetime'].dt.minute
            df['hour'] = df['datetime'].dt.hour
            df['weekday'] = df['datetime'].dt.weekday
            df['day'] = df['datetime'].dt.day
            df['month'] = df['datetime'].dt.month

            # Keep only necessary columns to save memory.
            self.data[symbol] = df[self.feature_list + self.time_feature_list]

            # Add all valid starting indices for this symbol to the global list.
            self.indices.extend((symbol, i) for i in range(num_samples))

        # The effective dataset size is the minimum of the configured
        # iterations and the total number of available samples.
        self.n_samples = min(self.n_samples, len(self.indices))
        print(
            f"[{data_type.upper()}] Found {len(self.indices)} possible samples. "
            f"Using {self.n_samples} per epoch."
        )

    def set_epoch_seed(self, epoch: int):
        """
        Reseeds the random sampler for a given epoch.

        The sampler RNG is reseeded with ``config.seed + epoch``, so the
        sequence of windows drawn after this call is determined by the
        configured seed and the value passed in.

        Args:
            epoch (int): The current epoch number.
        """
        epoch_seed = self.config.seed + epoch
        self.py_rng.seed(epoch_seed)

    def __len__(self) -> int:
        """Returns the number of samples per epoch."""
        return self.n_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a random sample from the dataset.

        Note:
            The `idx` argument is ignored. Instead, a random index is drawn
            from the pre-computed `self.indices` list using `self.py_rng`,
            so any window in the pool can be returned on any call.

        Args:
            idx (int): Ignored.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - x_tensor (torch.Tensor): float32 features of shape
                  ``(window, len(feature_list))``, normalized per feature
                  over the window and clipped to ``[-clip, clip]``.
                - x_stamp_tensor (torch.Tensor): float32 time features of
                  shape ``(window, len(time_feature_list))``.

        Raises:
            ValueError: If no sample windows are available.
        """
        # Fail with an explicit message instead of the opaque
        # "empty range" ValueError that randint(0, -1) would raise.
        if not self.indices:
            raise ValueError(
                "No sample windows are available: no series has at least "
                f"{self.window} rows."
            )

        # Select a random sample from the entire pool of indices.
        random_idx = self.py_rng.randint(0, len(self.indices) - 1)
        symbol, start_idx = self.indices[random_idx]

        # Extract the sliding window from the dataframe.
        df = self.data[symbol]
        end_idx = start_idx + self.window
        win_df = df.iloc[start_idx:end_idx]

        # Separate main features and time features.
        x = win_df[self.feature_list].values.astype(np.float32)
        x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

        # Perform instance-level normalization: statistics are computed per
        # feature column over this window only; the epsilon guards against
        # division by zero for constant columns.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.config.clip, self.config.clip)

        # Convert to PyTorch tensors.
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)

        return x_tensor, x_stamp_tensor


if __name__ == '__main__':
    # Example usage and verification.
    print("Creating training dataset instance...")
    train_dataset = QlibDataset(data_type='train')

    print(f"Dataset length: {len(train_dataset)}")

    if len(train_dataset) > 0:
        try_x, try_x_stamp = train_dataset[100]  # Index 100 is ignored.
        print(f"Sample feature shape: {try_x.shape}")
        print(f"Sample time feature shape: {try_x_stamp.shape}")
    else:
        print("Dataset is empty.")