# fcMRI-VAE / utils.py
# Committed by cindyhfls in 760c94e (verified):
#   "Basic files for implementing trained model"
# (File size on the repository page: 3.21 kB.)
import h5py
import torch
import torch.utils.data as data
import torch.multiprocessing
import scipy.io as sio
from torch.nn import functional as F
# torch.multiprocessing.set_start_method('spawn')
class H5Dataset(data.Dataset):
    """Map-style dataset yielding paired (left, right) float tensors from one HDF5 file.

    The file must contain datasets named ``'LeftData'`` and ``'RightData'``.
    Item ``index`` is the slab ``[index, :, :, :]`` of each, converted to
    float32 tensors. The stored arrays are therefore at least 4-D; presumably
    (sample, channel, H, W) -- TODO confirm against the data-prep code.

    The HDF5 file is opened lazily, on first data access in whichever process
    touches it. The open handle is dropped when the dataset is pickled, so each
    ``DataLoader`` worker opens its own read-only handle. This includes workers
    started with the 'spawn' method, which must pickle the dataset; h5py
    objects cannot be pickled. Without this, workers would share, or fail to
    receive, a handle opened in the parent process.

    Masks are no longer read from this file (update 2024.01.11: masks are
    loaded separately).
    """

    def __init__(self, H5Path):
        super().__init__()
        self.H5Path = H5Path
        self._h5file = None
        # Read the sample count with a short-lived handle. __len__ (called by
        # samplers in the main process) then never leaves a file open that
        # forked workers would inherit.
        with h5py.File(H5Path, 'r') as h5file:
            self._length = h5file['LeftData'].shape[0]

    @property
    def H5File(self):
        """This process's read-only ``h5py.File``, opened on first access."""
        if self._h5file is None:
            self._h5file = h5py.File(self.H5Path, 'r')
        return self._h5file

    @property
    def LeftData(self):
        """The ``'LeftData'`` dataset of the (lazily opened) file."""
        return self.H5File['LeftData']

    @property
    def RightData(self):
        """The ``'RightData'`` dataset of the (lazily opened) file."""
        return self.H5File['RightData']

    def __getstate__(self):
        # Pickle only the path and cached length; the receiving process
        # reopens the file on first access.
        state = self.__dict__.copy()
        state['_h5file'] = None
        return state

    def __getitem__(self, index):
        """Return ``(left, right)`` float32 tensors for sample ``index``."""
        left = torch.from_numpy(self.LeftData[index, :, :, :]).float()
        right = torch.from_numpy(self.RightData[index, :, :, :]).float()
        return left, right

    def __len__(self):
        """Number of samples, taken from the first axis of ``'LeftData'``."""
        return self._length

    def close(self):
        """Close this process's file handle if open; later access reopens it."""
        if self._h5file is not None:
            self._h5file.close()
            self._h5file = None
def save_image_mat(img_r, img_l, result_path, idx):
    """Save a pair of reconstructions to ``result_path + 'img<idx>.mat'``.

    The path is built by plain string concatenation. ``result_path`` should
    therefore end with a path separator, or be an intended filename prefix.

    Args:
        img_r: Right reconstruction tensor; stored under key ``'recon_R'``.
        img_l: Left reconstruction tensor; stored under key ``'recon_L'``.
        result_path: Directory (with trailing separator) or filename prefix.
        idx: Value substituted into the file name.
    """
    def _as_array(tensor):
        # Detach from the autograd graph and move to host memory before
        # converting, since scipy can only serialize NumPy arrays.
        return tensor.detach().cpu().numpy()

    mat_contents = {
        'recon_L': _as_array(img_l),
        'recon_R': _as_array(img_r),
    }
    out_file = result_path + 'img{}.mat'.format(idx)
    sio.savemat(out_file, mat_contents)
def load_dataset(data_path, batch_size):
    """Build the training and validation DataLoaders.

    Reads ``data_path + '_train.h5'`` and ``data_path + '_val.h5'``, each
    wrapped in an ``H5Dataset``. Both loaders keep the on-disk sample order
    (``shuffle=False``). When CUDA is available, loading uses one worker
    process and pinned memory; otherwise the DataLoader defaults apply.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if torch.cuda.is_available():
        loader_opts = dict(num_workers=1, pin_memory=True)
    else:
        loader_opts = {}

    # Open both datasets first (train, then val), then wrap each in a loader.
    datasets = [H5Dataset(data_path + suffix) for suffix in ('_train.h5', '_val.h5')]
    train_loader, val_loader = (
        data.DataLoader(ds, batch_size=batch_size, shuffle=False, **loader_opts)
        for ds in datasets
    )
    return train_loader, val_loader
def load_dataset_test(data_path, batch_size):
    """Build an unshuffled DataLoader over ``data_path + '.h5'``.

    When CUDA is available, loading uses one worker process and pinned
    memory; otherwise the DataLoader defaults apply.

    Returns:
        The test ``DataLoader``.
    """
    if torch.cuda.is_available():
        loader_opts = dict(num_workers=1, pin_memory=True)
    else:
        loader_opts = {}
    return data.DataLoader(
        H5Dataset(data_path + '.h5'),
        batch_size=batch_size,
        shuffle=False,
        **loader_opts,
    )
# loss function # update 20240109 mask out zeros
def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
    """Return the masked beta-VAE loss for paired left/right inputs.

    loss = KLD * beta / xL.size(3)**2 + MSE_L + MSE_R

    Args:
        xL, xR: Target tensors. ``xL.size(3)`` sets the KL rescaling, so xL
            must have at least 4 dims (presumably (batch, channel, H, W) with
            square images -- TODO confirm).
        x_recon_L, x_recon_R: Reconstructions of xL / xR.
        mu, logvar: Posterior mean and log-variance. The KL term is summed
            over dim 1 and averaged over dim 0 (the batch).
        beta: KL weight before rescaling. May be a float or a tensor; the
            caller's object is never modified.
        left_mask, right_mask: Optional masks broadcastable to xL / xR.
            Elements whose int32 value equals 1 are kept. Each may be None
            independently of the other.

    Returns:
        Scalar loss tensor.
    """
    image_size = xL.size(3)
    # Rebind instead of ``beta /= ...``. When ``beta`` is a tensor, the
    # augmented assignment divides it in place and would shrink the caller's
    # value on every call.
    beta = beta / image_size ** 2
    # Derivation (Tutorial on VAE, p. 14). N = 2 * 192 * 192, presumably the
    # pixel count of the left + right 192 x 192 images:
    #   log P(X|z) = C - \frac{1}{2} ||X - f(z)||^2 / \sigma^2
    #              = C - \frac{1}{2} \sum_{i=1}^{N} ||X^{(i)} - f(z^{(i)})||^2 / \sigma^2
    #              = C - \frac{1}{2} N \cdot \mathrm{MSE}(\hat{X}, X) / \sigma^2
    #   log P(X|z) - C = -\frac{1}{2} \cdot 2 \cdot 192 \cdot 192 / \sigma^2 \cdot \mathrm{MSE}
    #                  = -(36864 / \sigma^2) \cdot \mathrm{MSE}
    # Dividing the objective by 36864 / \sigma^2 gives
    # vae_beta = 1 / (36864 / \sigma^2). The code above divides the supplied
    # beta by image_size**2 (36864 when image_size == 192).

    # Ignore target elements that are exactly zero and, optionally, elements
    # outside the supplied masks (update 20240109).
    valid_mask_L = xL != 0
    valid_mask_R = xR != 0
    if left_mask is not None:
        valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
    if right_mask is not None:
        valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)
    # reduction='mean' is the non-deprecated spelling of size_average=True.
    # Masked-out elements contribute zero error but still count in the mean's
    # denominator, consistent with the full pixel count N used above.
    MSE_L = F.mse_loss(x_recon_L * valid_mask_L, xL * valid_mask_L, reduction='mean')
    MSE_R = F.mse_loss(x_recon_R * valid_mask_R, xR * valid_mask_R, reduction='mean')
    # KL divergence to N(0, I): summed over dim 1, averaged across batch samples.
    KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
    return KLD * beta + MSE_L + MSE_R