import h5py
import torch
import torch.utils.data as data
import torch.multiprocessing
import scipy.io as sio
from torch.nn import functional as F
# torch.multiprocessing.set_start_method('spawn')
class H5Dataset(data.Dataset):
    def __init__(self, H5Path):
        super(H5Dataset, self).__init__()
        self.H5File = h5py.File(H5Path, 'r')
        self.LeftData = self.H5File['LeftData']
        self.RightData = self.H5File['RightData']
        # self.LeftMask = self.H5File['LeftMask'][:]    # update 2024.01.11: masks are loaded separately
        # self.RightMask = self.H5File['RightMask'][:]

    def __getitem__(self, index):
        return (torch.from_numpy(self.LeftData[index, :, :, :]).float(),
                torch.from_numpy(self.RightData[index, :, :, :]).float())

    def __len__(self):
        return self.LeftData.shape[0]
def save_image_mat(img_r, img_l, result_path, idx):
    # move reconstructions to CPU numpy arrays and save them as a .mat file
    save_data = {}
    save_data['recon_L'] = img_l.detach().cpu().numpy()
    save_data['recon_R'] = img_r.detach().cpu().numpy()
    sio.savemat(result_path + 'img{}.mat'.format(idx), save_data)
def load_dataset(data_path, batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
    train_dir = data_path + '_train.h5'
    val_dir = data_path + '_val.h5'
    train_set = H5Dataset(train_dir)
    val_set = H5Dataset(val_dir)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, **kwargs)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, **kwargs)
    return train_loader, val_loader
def load_dataset_test(data_path, batch_size):
    kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
    test_dir = data_path + '.h5'
    test_set = H5Dataset(test_dir)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, **kwargs)
    return test_loader
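
# Usage sketch (assumptions, not taken from this file: an HDF5 pair
# 'subject01_train.h5' / 'subject01_val.h5' exists, each with 'LeftData' and
# 'RightData' arrays of shape (N, C, H, W)):
#
#     train_loader, val_loader = load_dataset('subject01', batch_size=8)
#     for xL, xR in train_loader:
#         xL, xR = xL.cuda(), xR.cuda()   # if a GPU is available
#         ...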
# loss function  # update 20240109: mask out zeros
def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
    Image_Size = xL.size(3)
    beta /= Image_Size ** 2
    # print('====> Image_Size: {} Beta: {:.8f}'.format(Image_Size, beta))
    # R_batch_size = xR.size(0)
    # Tutorial on VAE, page 14:
    #   log P(X|z) = C - \frac{1}{2\sigma^2} ||X - f(z)||^2
    #              = C - \frac{1}{2\sigma^2} \sum_{i=1}^{N} ||X^{(i)} - f(z^{(i)})||^2
    #              = C - \frac{1}{2\sigma^2} N * F.mse_loss(Xhat, Xtrue)
    #   log P(X|z) - C = -\frac{1}{2} * 2 * 192 * 192 / \sigma^2 * F.mse_loss
    # Therefore, vae_beta = \frac{1}{36864 / \sigma^2}
    # mask out zero-valued pixels so they do not contribute to the reconstruction error
    valid_mask_L = xL != 0
    valid_mask_R = xR != 0
    if left_mask is not None:
        valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
        valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)
    # masked reconstruction error, averaged over all elements
    MSE_L = F.mse_loss(x_recon_L * valid_mask_L, xL * valid_mask_L, reduction='mean')
    MSE_R = F.mse_loss(x_recon_R * valid_mask_R, xR * valid_mask_R, reduction='mean')
    # KLD is summed over latent dimensions and averaged across batch samples
    KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
    return KLD * beta + MSE_L + MSE_R
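
if __name__ == '__main__':
    # Minimal smoke-test sketch of loss_function under assumed shapes not taken
    # from this file: batch of 4, single-channel 192x192 images, 32-dimensional
    # latent space, beta = 1.0, no masks. Replace the random tensors with real
    # encoder/decoder outputs in an actual training loop.
    B, C, H, W = 4, 1, 192, 192
    xL = torch.rand(B, C, H, W)
    xR = torch.rand(B, C, H, W)
    x_recon_L = torch.rand(B, C, H, W)
    x_recon_R = torch.rand(B, C, H, W)
    mu = torch.zeros(B, 32)
    logvar = torch.zeros(B, 32)
    loss = loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar,
                         beta=1.0, left_mask=None, right_mask=None)
    print('loss: {:.6f}'.format(loss.item()))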