# crypt / finetune / dataset.py
# heyunfei's picture
# Upload 56 files
# 85653bc verified
import pickle
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from config import Config
class QlibDataset(Dataset):
    """
    A PyTorch Dataset for handling Qlib financial time series data.

    The constructor loads a pickled ``{symbol: DataFrame}`` mapping, derives
    calendar features from each frame's ``datetime`` column and pre-computes
    every valid ``(symbol, start_index)`` sliding-window position. Items are
    then drawn *randomly* from that pool; the index passed to
    ``__getitem__`` is ignored.

    Args:
        data_type (str): The type of dataset to load, either 'train' or 'val'.

    Raises:
        ValueError: If `data_type` is not 'train' or 'val'.
    """

    def __init__(self, data_type: str = 'train'):
        self.config = Config()
        if data_type not in ('train', 'val'):
            raise ValueError("data_type must be 'train' or 'val'")
        self.data_type = data_type

        # Use a dedicated random number generator for sampling to avoid
        # interfering with other random processes (e.g., in model initialization).
        self.py_rng = random.Random(self.config.seed)
        # Id of the DataLoader worker whose private stream `py_rng` currently
        # holds; None while the generator is still in its main-process state.
        self._seeded_worker_id = None

        # Set paths and number of samples based on the data type.
        if data_type == 'train':
            self.data_path = f"{self.config.dataset_path}/train_data.pkl"
            self.n_samples = self.config.n_train_iter
        else:
            self.data_path = f"{self.config.dataset_path}/val_data.pkl"
            self.n_samples = self.config.n_val_iter

        # NOTE(security): pickle.load can execute arbitrary code embedded in
        # the file. Only point `dataset_path` at files you produced yourself.
        with open(self.data_path, 'rb') as f:
            self.data = pickle.load(f)

        # Number of consecutive rows in one sample.
        self.window = self.config.lookback_window + self.config.predict_window + 1
        self.symbols = list(self.data.keys())
        self.feature_list = self.config.feature_list
        self.time_feature_list = self.config.time_feature_list

        # Pre-compute all possible (symbol, start_index) pairs.
        self.indices = []
        print(f"[{data_type.upper()}] Pre-computing sample indices...")
        for symbol in self.symbols:
            df = self.data[symbol].reset_index()
            num_samples = len(df) - self.window + 1
            if num_samples <= 0:
                # Series is shorter than one window: it contributes no samples
                # and its frame is left untouched.
                continue

            # Generate time features and store them directly in the dataframe.
            stamps = df['datetime'].dt
            df['minute'] = stamps.minute
            df['hour'] = stamps.hour
            df['weekday'] = stamps.weekday
            df['day'] = stamps.day
            df['month'] = stamps.month

            # Keep only necessary columns to save memory.
            self.data[symbol] = df[self.feature_list + self.time_feature_list]

            # Add all valid starting indices for this symbol to the global list.
            self.indices.extend((symbol, i) for i in range(num_samples))

        # The effective dataset size is the minimum of the configured iterations
        # and the total number of available samples.
        self.n_samples = min(self.n_samples, len(self.indices))
        print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")

    def set_epoch_seed(self, epoch: int):
        """
        Re-seeds the random sampler with ``config.seed + epoch``.

        Call this before iterating an epoch so that every epoch draws a
        different, yet reproducible, sequence of windows.

        Args:
            epoch (int): The current epoch number.
        """
        epoch_seed = self.config.seed + epoch
        self.py_rng.seed(epoch_seed)
        # The generator is back in a shared state; a worker must derive its
        # private stream again on next use.
        self._seeded_worker_id = None

    def _sampling_rng(self) -> random.Random:
        """
        Returns the generator to draw from in the current process.

        In the main process this is `self.py_rng` unchanged. Inside a
        DataLoader worker the dataset is a copy, so every worker would start
        from the identical generator state and produce the same samples; on
        first use in a worker the generator is therefore re-seeded once from
        its current state plus the worker id.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return self.py_rng
        if getattr(self, '_seeded_worker_id', None) != worker_info.id:
            # Same base value in every worker (identical copied state), made
            # distinct by the worker id; still fully determined by the seed.
            base_seed = self.py_rng.getrandbits(64)
            self.py_rng.seed(base_seed + worker_info.id)
            self._seeded_worker_id = worker_info.id
        return self.py_rng

    def __len__(self) -> int:
        """Returns the number of samples per epoch."""
        return self.n_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a random sample from the dataset.

        Note: The `idx` argument is ignored. Instead, a random index is drawn
        from the pre-computed `self.indices` list using `self.py_rng`. This
        ensures random sampling over the entire dataset for each call.

        Args:
            idx (int): Ignored.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - x_tensor (torch.Tensor): The normalized feature tensor,
                  float32 with one row per window step and one column per
                  entry of `feature_list`.
                - x_stamp_tensor (torch.Tensor): The time feature tensor,
                  float32 with one column per entry of `time_feature_list`.

        Raises:
            ValueError: If the dataset holds no samples (empty index pool).
        """
        # Select a random sample from the entire pool of indices.
        rng = self._sampling_rng()
        random_idx = rng.randint(0, len(self.indices) - 1)
        symbol, start_idx = self.indices[random_idx]

        # Extract the sliding window from the dataframe.
        win_df = self.data[symbol].iloc[start_idx:start_idx + self.window]

        # Separate main features and time features.
        x = win_df[self.feature_list].values.astype(np.float32)
        x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

        # Perform instance-level normalization: per-column statistics are
        # computed over this window only; the epsilon guards constant columns.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.config.clip, self.config.clip)

        # Convert to PyTorch tensors.
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)
        return x_tensor, x_stamp_tensor
if __name__ == '__main__':
    # Manual smoke check: build the training split and print the shapes of
    # one randomly drawn sample.
    print("Creating training dataset instance...")
    train_dataset = QlibDataset(data_type='train')
    dataset_size = len(train_dataset)
    print(f"Dataset length: {dataset_size}")

    if dataset_size <= 0:
        print("Dataset is empty.")
    else:
        # The index passed here is irrelevant; sampling is random.
        sample_x, sample_stamp = train_dataset[100]
        print(f"Sample feature shape: {sample_x.shape}")
        print(f"Sample time feature shape: {sample_stamp.shape}")