|
import h5py |
|
import torch |
|
import torch.utils.data as data |
|
import torch.multiprocessing |
|
import scipy.io as sio |
|
from torch.nn import functional as F |
|
|
|
|
|
class H5Dataset(data.Dataset):
    """Map-style dataset yielding paired left/right samples from an HDF5 file.

    The file at ``H5Path`` must contain two datasets, ``'LeftData'`` and
    ``'RightData'``. Item ``index`` is the pair of ``[index, :, :, :]``
    slices of those datasets, each converted to a float32 tensor, so both
    datasets are expected to be 4-D with samples along the first axis.
    The dataset length is taken from ``LeftData`` only.

    h5py objects cannot be pickled, which breaks ``DataLoader`` worker
    processes started with the ``spawn`` method (the default on Windows and
    macOS). ``__getstate__``/``__setstate__`` therefore drop the open
    handles when pickling and reopen the file from its path on unpickling.
    """

    def __init__(self, H5Path):
        super(H5Dataset, self).__init__()
        # Remember the path so a copy of this dataset in another process
        # can reopen the file itself.
        self.H5Path = H5Path
        self._open()

    def _open(self):
        """Open the HDF5 file read-only and bind its two datasets."""
        self.H5File = h5py.File(self.H5Path, 'r')
        self.LeftData = self.H5File['LeftData']
        self.RightData = self.H5File['RightData']

    def __getstate__(self):
        # Exclude the unpicklable h5py handles; only plain attributes
        # (including the path) are transferred.
        state = self.__dict__.copy()
        for key in ('H5File', 'LeftData', 'RightData'):
            state.pop(key, None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        # Reopen in the receiving process instead of sharing a handle.
        self._open()

    def __getitem__(self, index):
        return (torch.from_numpy(self.LeftData[index, :, :, :]).float(),
                torch.from_numpy(self.RightData[index, :, :, :]).float())

    def __len__(self):
        return self.LeftData.shape[0]
|
|
|
def save_image_mat(img_r, img_l, result_path, idx):
    """Write a right/left reconstruction pair to a MATLAB ``.mat`` file.

    Each tensor is detached, copied to the CPU and converted to a NumPy
    array. ``img_l`` is stored under ``'recon_L'`` and ``img_r`` under
    ``'recon_R'``.

    The destination is ``result_path`` concatenated directly with
    ``'img<idx>.mat'``. No separator is inserted, so ``result_path`` must
    already end with one if it names a directory.
    """
    def as_array(tensor):
        return tensor.detach().cpu().numpy()

    payload = {
        'recon_L': as_array(img_l),
        'recon_R': as_array(img_r),
    }
    sio.savemat(result_path + 'img{}.mat'.format(idx), payload)
|
|
|
def load_dataset(data_path, batch_size):
    """Create training and validation loaders from paired HDF5 files.

    Reads ``<data_path>_train.h5`` and ``<data_path>_val.h5`` through
    ``H5Dataset``. Neither loader shuffles.
    NOTE(review): confirm that an unshuffled training loader is intended.
    When CUDA is available, each loader uses one worker process and pinned
    memory; otherwise the DataLoader defaults apply.

    Returns:
        ``(train_loader, val_loader)``.
    """
    loader_opts = {}
    if torch.cuda.is_available():
        loader_opts = {'num_workers': 1, 'pin_memory': True}

    # Both datasets are opened before any loader is built, as before.
    train_set, val_set = [H5Dataset(data_path + suffix)
                          for suffix in ('_train.h5', '_val.h5')]

    train_loader, val_loader = [
        torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                    shuffle=False, **loader_opts)
        for dataset in (train_set, val_set)
    ]
    return train_loader, val_loader
|
|
|
def load_dataset_test(data_path, batch_size):
    """Create an unshuffled loader over the single file ``<data_path>.h5``.

    When CUDA is available, the loader uses one worker process and pinned
    memory; otherwise the DataLoader defaults apply.
    """
    loader_opts = {}
    if torch.cuda.is_available():
        loader_opts = {'num_workers': 1, 'pin_memory': True}

    return torch.utils.data.DataLoader(
        H5Dataset(data_path + '.h5'),
        batch_size=batch_size,
        shuffle=False,
        **loader_opts,
    )
|
|
|
|
|
def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
    """Masked left/right reconstruction loss plus a beta-weighted KL term.

    Reconstruction: for each side, elements whose target value is exactly
    zero are excluded by zeroing both prediction and target there. If that
    side's mask is given, elements where the mask (cast to int32) is not 1
    are excluded as well. The squared error is then averaged over *all*
    elements: masked positions contribute zero but still count in the mean.

    KL term: ``-0.5 * (1 + logvar - mu**2 - exp(logvar))`` summed over
    dim 1 and averaged over dim 0. This is the closed-form KL divergence of
    a diagonal Gaussian from a standard normal when ``mu``/``logvar`` are
    (batch, latent). TODO: confirm shapes with the caller.

    Args:
        xL, xR: target tensors. ``xL.size(3)`` is used as the image size,
            presumably (batch, channel, H, W) with H == W (confirm).
        x_recon_L, x_recon_R: reconstructions matching the targets' shapes.
        mu, logvar: latent mean and log-variance.
        beta: KL weight, divided by ``xL.size(3) ** 2`` before use. The
            caller's value is never modified.
        left_mask, right_mask: optional masks, or ``None`` to skip. Each
            side is handled independently.

    Returns:
        Scalar tensor ``KLD * beta / size(3)**2 + MSE_L + MSE_R``.
    """
    image_size = xL.size(3)

    # Out-of-place division: ``beta /= ...`` would divide a caller's tensor
    # in place and corrupt it for subsequent calls.
    beta = beta / image_size ** 2

    valid_mask_L = xL != 0
    valid_mask_R = xR != 0

    # Apply each mask on its own. Previously only left_mask was checked,
    # so right_mask=None crashed and a lone right_mask was ignored.
    if left_mask is not None:
        valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
    if right_mask is not None:
        valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)

    # reduction='mean' is the non-deprecated equivalent of size_average=True.
    MSE_L = F.mse_loss(x_recon_L * valid_mask_L, xL * valid_mask_L, reduction='mean')
    MSE_R = F.mse_loss(x_recon_R * valid_mask_R, xR * valid_mask_R, reduction='mean')

    KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()

    return KLD * beta + MSE_L + MSE_R
|
|
|
|