import numpy as np
import torch
import pytorch_lightning as pl
from torch.utils import data
from torch.utils.data import DataLoader

import constants as cst


class Dataset(data.Dataset):
    """Characterizes a dataset for PyTorch."""

    def __init__(self, x, y, seq_size):
        """Initialization"""
        self.seq_size = seq_size
        self.length = y.shape[0]
        self.x = x
        self.y = y
        # Accept either numpy arrays or tensors; convert numpy input once
        # up front so __getitem__ stays cheap.
        if isinstance(self.x, np.ndarray):
            self.x = torch.from_numpy(x).float()
        if isinstance(self.y, np.ndarray):
            self.y = torch.from_numpy(y).long()
        # Alias used by DataModule to inspect the device of the tensors.
        self.data = self.x

    def __len__(self):
        """Denotes the total number of samples."""
        return self.length

    def __getitem__(self, i):
        # Each item is a sliding window of seq_size consecutive rows of x,
        # paired with the label at the window's starting index. The caller
        # is expected to supply x with at least seq_size - 1 rows beyond
        # the last label so that every window is complete.
        x_window = self.x[i:i + self.seq_size, :]
        return x_window, self.y[i]
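
# Illustration of the window indexing in Dataset.__getitem__ (a sketch, not
# from the original file): with seq_size = 10 and x of shape (509, 40), item
# i = 0 returns x[0:10] as a (10, 40) window with label y[0]; the last valid
# item i = 499 returns x[499:509], so len(y) = 500 fits 509 rows of x.
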
class DataModule(pl.LightningDataModule):

    def __init__(self, train_set, val_set, batch_size, test_batch_size, is_shuffle_train=True, test_set=None, num_workers=16):
        super().__init__()
        self.train_set = train_set
        self.val_set = val_set
        self.test_set = test_set
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.is_shuffle_train = is_shuffle_train
        # Pin host memory only when the data still has to be moved to the
        # target device (e.g. CPU tensors feeding a CUDA model).
        self.pin_memory = train_set.data.device.type != cst.DEVICE
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(
            dataset=self.train_set,
            batch_size=self.batch_size,
            shuffle=self.is_shuffle_train,
            pin_memory=self.pin_memory,
            drop_last=False,
            num_workers=self.num_workers,
            # persistent_workers requires num_workers > 0, so guard it.
            persistent_workers=self.num_workers > 0,
        )

    def val_dataloader(self):
        return DataLoader(
            dataset=self.val_set,
            batch_size=self.test_batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            drop_last=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self):
        # Assumes a test_set was provided; DataLoader(None) would fail.
        return DataLoader(
            dataset=self.test_set,
            batch_size=self.test_batch_size,
            shuffle=False,
            pin_memory=self.pin_memory,
            drop_last=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )
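

# Minimal usage sketch (not from the original file). Assumptions: labels are
# aligned with window starts, x carries seq_size - 1 extra rows so the last
# window is complete, and cst.DEVICE is a device-type string such as "cpu"
# or "cuda". num_workers=0 keeps the example single-process.
if __name__ == "__main__":
    seq_size, n_features = 10, 40
    x = np.random.randn(509, n_features).astype(np.float32)  # 500 + 10 - 1
    y = np.random.randint(0, 3, size=500)

    train_set = Dataset(x, y, seq_size)
    val_set = Dataset(x, y, seq_size)

    dm = DataModule(train_set, val_set, batch_size=32, test_batch_size=64,
                    num_workers=0)
    batch_x, batch_y = next(iter(dm.train_dataloader()))
    print(batch_x.shape, batch_y.shape)  # (32, 10, 40) and (32,)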