|
class CustomDataset(Dataset):
    """Map-style dataset pairing spectrogram images with target vectors.

    Each item is built from one row of ``df``. Channels 0-3 of a
    (128, 256, 8) float32 image come from four 100-column regions of a
    300-row window of the row's spectrogram; each region is log-transformed
    and standardized. Channels 4-7 come from the row's pre-computed EEG
    spectrogram.

    Args:
        df: One row per sample. Always reads ``spectrogram_id`` and
            ``label_id``. Outside test mode it also reads
            ``spectrogram_label_offset_seconds`` and the ``TARGETS`` columns.
        augment: When True, apply random stripe masking (see ``__transform``).
        mode: ``'test'`` starts the window at row 0 and returns all-zero labels.
        specs: Maps ``spectrogram_id`` to a 2-D spectrogram array. ``None``
            (default) falls back to the module-level ``spectrograms``.
        eeg_specs: Maps ``label_id`` to an array assignable to ``X[:, :, 4:]``
            (e.g. shape (128, 256, 4)). ``None`` (default) falls back to the
            module-level ``all_eegs``.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        augment: bool = False,
        mode: str = 'train',
        specs: "Dict[int, np.ndarray] | None" = None,
        eeg_specs: "Dict[int, np.ndarray] | None" = None,
    ):
        self.df = df
        self.augment = augment
        self.mode = mode
        # Honour the ``specs`` argument (it used to be ignored in favour of the
        # global). Fallbacks are resolved at call time rather than as default
        # values, so no shared mutable dict is captured when the class is
        # defined. The attribute names keep their historical spelling because
        # other code may read them.
        self.spectograms = spectrograms if specs is None else specs
        self.eeg_spectograms = all_eegs if eeg_specs is None else eeg_specs

    def __len__(self):
        """Return the number of samples, i.e. the number of rows in ``df``."""
        return len(self.df)

    def __getitem__(self, index):
        """Return sample ``index`` as a dict of float32 tensors.

        Keys:
            ``"spectrogram"``: the (128, 256, 8) image, masked when
            ``augment`` is True.
            ``"labels"``: the target vector (zeros of length 6 in test mode).
        """
        X, y = self.__data_generation(index)
        if self.augment:
            X = self.__transform(X)
        return {
            "spectrogram": torch.tensor(X, dtype=torch.float32),
            "labels": torch.tensor(y, dtype=torch.float32),
        }

    def __data_generation(self, index):
        """Build the ``(X, y)`` numpy arrays for the row at position ``index``.

        ``X`` is float32 with shape (128, 256, 8). ``y`` is float32: six zeros
        in test mode, otherwise the row's ``TARGETS`` values.
        """
        X = np.zeros((128, 256, 8), dtype='float32')
        y = np.zeros(6, dtype='float32')
        row = self.df.iloc[index]

        # Start row of the 300-row window. The offset is halved, presumably
        # because spectrogram rows are 2 s apart — TODO confirm.
        if self.mode == 'test':
            r = 0
        else:
            r = int(row['spectrogram_label_offset_seconds'] // 2)

        spectrogram = self.spectograms[row.spectrogram_id]
        ep = 1e-6  # guards the division when a region has zero variance
        for region in range(4):
            # Columns [100*region, 100*(region+1)) of the window, transposed to
            # shape (100, <=300).
            img = spectrogram[r:r + 300, region * 100:(region + 1) * 100].T

            # Log transform. Clipping first keeps log() away from values <= 0.
            img = np.clip(img, np.exp(-4), np.exp(8))
            img = np.log(img)

            # Standardize per region. NaNs are ignored by the statistics and
            # then replaced with 0.
            mu = np.nanmean(img.flatten())
            std = np.nanstd(img.flatten())
            img = (img - mu) / (std + ep)
            img = np.nan_to_num(img, nan=0.0)

            # Trim 22 columns on each side (300 -> 256) and place the result in
            # the middle 100 rows (128 - 2*14).
            # NOTE(review): if the spectrogram has fewer than r+300 rows, this
            # assignment raises a shape-mismatch error.
            X[14:-14, :, region] = img[:, 22:-22] / 2.0

        # The EEG spectrogram does not depend on ``region``, so copy it once.
        X[:, :, 4:] = self.eeg_spectograms[row.label_id]

        if self.mode != 'test':
            y = row[TARGETS].values.astype(np.float32)

        return X, y

    def __transform(self, img):
        """Apply random stripe masking to ``img`` with albumentations XYMasking.

        Three ``XYMasking`` transforms each fire with probability 0.3:
        x-stripes only, y-stripes only, and a mix of both. The pipeline is
        built on every call, so no random state is shared across DataLoader
        worker processes.
        """
        # NOTE(review): fill_value looks like a per-channel value, which would
        # fill channel c with the value c. That is unusual for standardized
        # data — confirm it was not meant to be 0 for all 8 channels.
        params1 = {
            "num_masks_x": 1,
            "mask_x_length": (0, 20),
            "fill_value": (0, 1, 2, 3, 4, 5, 6, 7),
        }
        params2 = {
            "num_masks_y": 1,
            "mask_y_length": (0, 20),
            "fill_value": (0, 1, 2, 3, 4, 5, 6, 7),
        }
        params3 = {
            "num_masks_x": (2, 4),
            "num_masks_y": 5,
            "mask_y_length": 8,
            "mask_x_length": (10, 20),
            "fill_value": (0, 1, 2, 3, 4, 5, 6, 7),
        }

        transforms = A.Compose([
            A.XYMasking(**params1, p=0.3),
            A.XYMasking(**params2, p=0.3),
            A.XYMasking(**params3, p=0.3),
        ])
        return transforms(image=img)['image']
|
|