# Hugging Face file-viewer residue from the original upload (kept as comments
# so the module is valid Python):
# LeonardoBerti's picture
# Upload 51 files
# 69524d0 verified
import torch
from torch.utils import data
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import numpy as np
import constants as cst
import time
from torch.utils import data
from utils.utils_data import one_hot_encoding_type, tanh_encoding_type
class Dataset(data.Dataset):
    """PyTorch dataset yielding sliding windows over a feature matrix.

    Sample ``i`` is the window ``x[i : i + seq_size]`` paired with label
    ``y[i]``.

    Args:
        x: feature matrix of shape (T, F); a ``np.ndarray`` is converted
            to a float tensor, anything else is stored as-is.
        y: labels of shape (T,); a ``np.ndarray`` is converted to a long
            tensor, anything else is stored as-is.
        seq_size: number of consecutive rows returned per sample.
    """
    def __init__(self, x, y, seq_size):
        """Store the data, converting numpy inputs to torch tensors."""
        self.seq_size = seq_size
        self.length = y.shape[0]
        self.x = x
        self.y = y
        # isinstance (rather than type ==) also converts ndarray subclasses
        if isinstance(self.x, np.ndarray):
            self.x = torch.from_numpy(x).float()
        if isinstance(self.y, np.ndarray):
            self.y = torch.from_numpy(y).long()
        # Alias kept on purpose: DataModule reads `dataset.data.device`
        # to decide whether to pin memory.
        self.data = self.x
    def __len__(self):
        """Denotes the total number of samples."""
        return self.length
    def __getitem__(self, i):
        # NOTE(review): for i > len(x) - seq_size the slice runs past the
        # end and silently returns a shorter window — confirm callers
        # pre-pad x or that length should be T - seq_size + 1.
        window = self.x[i:i + self.seq_size, :]
        return window, self.y[i]
class DataModule(pl.LightningDataModule):
    """Lightning data module wrapping train/val/test datasets in DataLoaders.

    Args:
        train_set: training dataset; must expose a ``.data`` tensor (used to
            decide memory pinning, see below).
        val_set: validation dataset.
        batch_size: batch size for training.
        test_batch_size: batch size for validation and test.
        is_shuffle_train: whether to shuffle the training loader.
        test_set: optional test dataset; ``test_dataloader`` requires it.
        num_workers: worker processes per DataLoader.
    """
    def __init__(self, train_set, val_set, batch_size, test_batch_size, is_shuffle_train=True, test_set=None, num_workers=16):
        super().__init__()
        self.train_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.is_shuffle_train = is_shuffle_train
        # Pin host memory only when the data still lives on a device other
        # than the target (cst.DEVICE) — i.e. training on GPU while the
        # tensors are still on the CPU; pinning is useless otherwise.
        self.pin_memory = train_set.data.device.type != cst.DEVICE
        self.num_workers = num_workers
        # persistent_workers=True with num_workers=0 makes DataLoader raise
        # ValueError, so only enable it when worker processes exist.
        self.persistent_workers = num_workers > 0

    def _make_loader(self, dataset, batch_size, shuffle):
        """Build a DataLoader with the module-wide loader settings."""
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=self.pin_memory,
            drop_last=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers
        )

    def train_dataloader(self):
        """Training loader; shuffling follows ``is_shuffle_train``."""
        return self._make_loader(self.train_set, self.batch_size, self.is_shuffle_train)

    def val_dataloader(self):
        """Validation loader; never shuffled."""
        return self._make_loader(self.val_set, self.test_batch_size, False)

    def test_dataloader(self):
        """Test loader; never shuffled. Requires ``test_set`` to be set."""
        return self._make_loader(self.test_set, self.test_batch_size, False)