# crypt/finetune/dataset.py
# Uploaded by heyunfei ("Upload 56 files"), commit 85653bc (verified), 5.75 kB.
import pickle
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from config import Config
class QlibDataset(Dataset):
    """
    A PyTorch Dataset for handling Qlib financial time series data.

    This dataset pre-computes all possible start indices for sliding windows
    and then randomly samples from them during training/validation.

    Sampling is driven by a dedicated ``random.Random`` instance rather than by
    the ``idx`` passed to ``__getitem__``. Inside DataLoader worker processes,
    each worker re-seeds its copy of that generator with its worker id mixed
    in, so workers do not all emit the same sequence of windows. In the main
    process (``num_workers=0``) the sampling sequence depends only on the seed.

    Args:
        data_type (str): The type of dataset to load, either 'train' or 'val'.

    Raises:
        ValueError: If `data_type` is not 'train' or 'val'.
    """

    def __init__(self, data_type: str = 'train'):
        self.config = Config()
        if data_type not in ['train', 'val']:
            raise ValueError("data_type must be 'train' or 'val'")
        self.data_type = data_type

        # Use a dedicated random number generator for sampling to avoid
        # interfering with other random processes (e.g., in model initialization).
        self.py_rng = random.Random(self.config.seed)
        # Seed most recently applied to py_rng, and the DataLoader worker id that
        # py_rng has been re-seeded for (None = not specialised to a worker yet).
        self._rng_base_seed = self.config.seed
        self._rng_worker_id = None

        # Set paths and number of samples based on the data type.
        if data_type == 'train':
            self.data_path = f"{self.config.dataset_path}/train_data.pkl"
            self.n_samples = self.config.n_train_iter
        else:
            self.data_path = f"{self.config.dataset_path}/val_data.pkl"
            self.n_samples = self.config.n_val_iter

        # NOTE(security): pickle.load can execute arbitrary code; only point
        # dataset_path at files produced by a trusted pipeline.
        with open(self.data_path, 'rb') as f:
            self.data = pickle.load(f)

        # A sample spans the lookback window, the prediction window and one
        # extra step.
        self.window = self.config.lookback_window + self.config.predict_window + 1
        self.symbols = list(self.data.keys())
        self.feature_list = self.config.feature_list
        self.time_feature_list = self.config.time_feature_list

        # Pre-compute all possible (symbol, start_index) pairs.
        self.indices = []
        print(f"[{data_type.upper()}] Pre-computing sample indices...")
        for symbol in self.symbols:
            # Assumes each value is a DataFrame whose index (or a column) is a
            # datetime-typed 'datetime' — TODO confirm against preprocessing.
            df = self.data[symbol].reset_index()
            num_samples = len(df) - self.window + 1
            if num_samples <= 0:
                # Too short for even one full window: the symbol contributes
                # no samples and its frame is left as loaded.
                continue

            # Generate time features and store them directly in the dataframe.
            df['minute'] = df['datetime'].dt.minute
            df['hour'] = df['datetime'].dt.hour
            df['weekday'] = df['datetime'].dt.weekday
            df['day'] = df['datetime'].dt.day
            df['month'] = df['datetime'].dt.month
            # Keep only necessary columns to save memory.
            self.data[symbol] = df[self.feature_list + self.time_feature_list]
            # Add all valid starting indices for this symbol to the global list.
            self.indices.extend((symbol, i) for i in range(num_samples))

        # The effective dataset size is the minimum of the configured iterations
        # and the total number of available samples.
        self.n_samples = min(self.n_samples, len(self.indices))
        print(f"[{data_type.upper()}] Found {len(self.indices)} possible samples. Using {self.n_samples} per epoch.")

    def set_epoch_seed(self, epoch: int):
        """
        Sets a new seed for the random sampler for each epoch. This is crucial
        for reproducibility in distributed training.

        The generator is re-seeded with ``config.seed + epoch``. If this dataset
        is later copied into DataLoader workers, each worker further mixes its
        worker id into this seed on first use.

        Args:
            epoch (int): The current epoch number.
        """
        epoch_seed = self.config.seed + epoch
        self.py_rng.seed(epoch_seed)
        self._rng_base_seed = epoch_seed
        # Force worker copies made after this call to re-specialise py_rng.
        self._rng_worker_id = None

    def _reseed_for_worker_if_needed(self) -> None:
        """Give each DataLoader worker its own deterministic sampling stream.

        Workers receive copies of this dataset with identical ``py_rng`` state.
        PyTorch re-seeds only the global ``random``/``torch``/``numpy``
        generators per worker, not private ``random.Random`` instances. Without
        this step every worker would draw the same windows. Does nothing in the
        main process.
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            return
        if getattr(self, '_rng_worker_id', None) == worker_info.id:
            return
        self._seed_rng_for_worker(worker_info.id)

    def _seed_rng_for_worker(self, worker_id: int) -> None:
        """Re-seed ``py_rng`` from the current base seed combined with ``worker_id``."""
        base_seed = getattr(self, '_rng_base_seed', self.config.seed)
        # random.Random hashes str seeds deterministically (seed version 2), so
        # distinct (base_seed, worker_id) pairs cannot collide arithmetically.
        self.py_rng.seed(f"{base_seed}:worker{worker_id}")
        self._rng_worker_id = worker_id

    def __len__(self) -> int:
        """Returns the number of samples per epoch."""
        return self.n_samples

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Retrieves a random sample from the dataset.

        Note: The `idx` argument is ignored. Instead, a random index is drawn
        from the pre-computed `self.indices` list using `self.py_rng`. This
        ensures random sampling over the entire dataset for each call.

        Args:
            idx (int): Ignored.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - x_tensor (torch.Tensor): float32 features of shape
                  (window, len(feature_list)). Each feature is standardized
                  over the window and clipped to [-config.clip, config.clip].
                - x_stamp_tensor (torch.Tensor): float32 time features of
                  shape (window, len(time_feature_list)).

        Raises:
            ValueError: If there are no valid samples to draw from.
        """
        self._reseed_for_worker_if_needed()
        if not self.indices:
            raise ValueError("QlibDataset has no valid samples to draw from.")

        # Select a random sample from the entire pool of indices.
        random_idx = self.py_rng.randint(0, len(self.indices) - 1)
        symbol, start_idx = self.indices[random_idx]

        # Extract the sliding window from the dataframe.
        df = self.data[symbol]
        end_idx = start_idx + self.window
        win_df = df.iloc[start_idx:end_idx]

        # Separate main features and time features.
        x = win_df[self.feature_list].values.astype(np.float32)
        x_stamp = win_df[self.time_feature_list].values.astype(np.float32)

        # Perform instance-level normalization; the epsilon guards flat series.
        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
        x = (x - x_mean) / (x_std + 1e-5)
        x = np.clip(x, -self.config.clip, self.config.clip)

        # Convert to PyTorch tensors.
        x_tensor = torch.from_numpy(x)
        x_stamp_tensor = torch.from_numpy(x_stamp)
        return x_tensor, x_stamp_tensor
if __name__ == '__main__':
    # Smoke check: build the training split and draw a single sample from it.
    print("Creating training dataset instance...")
    dataset = QlibDataset(data_type='train')
    print(f"Dataset length: {len(dataset)}")
    if not len(dataset):
        print("Dataset is empty.")
    else:
        # __getitem__ ignores its index; any value yields a random window.
        sample_x, sample_x_stamp = dataset[100]
        print(f"Sample feature shape: {sample_x.shape}")
        print(f"Sample time feature shape: {sample_x_stamp.shape}")