cindyhfls committed · Commit 760c94e · verified · 1 Parent(s): 27caee2

Basic files for implementing trained model
Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:81de496d81529e84942d3ca6652e578ad412712044e472bcd01635e0533609f3
+ size 48127946
Checkpoint/checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:93e6aac3dc3eeb3296ce2e24a480e5b704b85beb32360a31953409fb7dfc00ac
+ size 48791690
Checkpoint/checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f48320b2e5a4ae6d668e22284b50470d8444dd99cd82427ad0ccead34fdd9bbd
+ size 48459722
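Each checkpoint filename encodes its training configuration: epoch (checkpoint49), training date, latent dimensionality (Zdim), beta-VAE weight (Vae-beta), learning rate (Lr), batch size, and the washu120 data with its subsampling and train/val split. If that configuration is needed programmatically, the filename can be parsed back into a dict. A minimal sketch; parse_checkpoint_name is a hypothetical helper, not part of this commit:

import os
import re

def parse_checkpoint_name(path):
    """Hypothetical helper (not part of this commit): recover hyperparameters
    from a checkpoint filename such as the three committed above."""
    name = os.path.basename(path)
    params = {}
    m = re.match(r'checkpoint(\d+)_(\d{4}-\d{2}-\d{2})_', name)
    if m:
        params['epoch'], params['date'] = int(m.group(1)), m.group(2)
    for key, pattern, cast in [('zdim', r'Zdim_(\d+)', int),
                               ('vae_beta', r'Vae-beta_([\d.]+)', float),
                               ('lr', r'Lr_([\d.]+)', float),
                               ('batch_size', r'Batch-size_(\d+)', int)]:
        m = re.search(pattern, name)
        if m:
            params[key] = cast(m.group(1))
    return params

# parse_checkpoint_name('Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_'
#                       'Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar')
# -> {'epoch': 49, 'date': '2024-03-28', 'zdim': 2, 'vae_beta': 20.0, 'lr': 0.0001, 'batch_size': 128}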
VAE_inference_example.py ADDED
@@ -0,0 +1,79 @@
+ import torch  # tested on version 2.1.2+cu118
+ import scipy.io as io
+ import argparse
+ import logging
+ import os
+ from utils import load_dataset_test, save_image_mat
+ from fMRIVAE_Model import BetaVAE
+
+ def main():
+     parser = argparse.ArgumentParser(description='VAE for fMRI generation')
+     parser.add_argument('--batch-size', type=int, metavar='N', help='how many samples per saved file?')
+     parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
+     parser.add_argument('--zdim', type=int, default=256, metavar='N', help='dimension of latent variables')
+     parser.add_argument('--data-path', type=str, metavar='DIR', help='path to dataset')
+     parser.add_argument('--z-path', type=str, default='./result/latent/', help='path to saved z files')
+     parser.add_argument('--resume', type=str, default='./checkpoint/checkpoint.pth.tar', help='the VAE checkpoint')
+     parser.add_argument('--img-path', type=str, default='./result/recon', help='path to save reconstructed images')
+     parser.add_argument('--mode', type=str, default='both', help="choose from 'encode', 'decode' or 'both'")
+     parser.add_argument('--debug', action='store_true', help='Enable debug mode for detailed logging')
+     args = parser.parse_args()
+
+     os.makedirs(args.z_path, exist_ok=True)
+     if args.mode != 'encode':
+         os.makedirs(args.img_path, exist_ok=True)
+
+     # Set logging level based on debug flag
+     logging_level = logging.DEBUG if args.debug else logging.INFO
+     logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')
+     logging.debug("Starting the VAE inference script.")
+     logging.debug(f"Parsed arguments: {args}")
+
+     try:
+         torch.manual_seed(args.seed)
+         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logging.debug(f"Using device: {device}")
+
+         logging.debug(f"Loading VAE model from {args.resume}.")
+         model = BetaVAE(z_dim=args.zdim, nc=1).to(device)
+         if os.path.isfile(args.resume):
+             checkpoint = torch.load(args.resume, map_location=device)
+             model.load_state_dict(checkpoint['state_dict'])
+             logging.debug("Checkpoint loaded.")
+         else:
+             logging.error(f"Checkpoint not found at {args.resume}")
+             raise RuntimeError("Checkpoint not found.")
+         model.eval()
+
+         if args.mode in ('encode', 'both'):
+             logging.debug("Starting encoding process...")
+             test_loader = load_dataset_test(args.data_path, args.batch_size)
+             logging.debug(f"Loaded test dataset from {args.data_path}")
+             for batch_idx, (xL, xR) in enumerate(test_loader):
+                 xL = xL.to(device)
+                 xR = xR.to(device)
+                 z_distribution = model._encode(xL, xR)  # columns: [mu, logvar]
+                 save_data = {'z_distribution': z_distribution.detach().cpu().numpy()}
+                 io.savemat(os.path.join(args.z_path, f'save_z{batch_idx}.mat'), save_data)
+                 logging.debug(f"Encoded batch {batch_idx}")
+
+         if args.mode in ('decode', 'both'):
+             logging.debug("Starting decoding process...")
+             filelist = [f for f in os.listdir(args.z_path) if f.split('_')[0] == 'save']
+             logging.debug(f"Filelist: {filelist}")
+             for batch_idx, filename in enumerate(filelist):
+                 logging.debug(f"Decoding file {filename}")
+                 z_dist = io.loadmat(os.path.join(args.z_path, filename))  # load the listed file, not a reconstructed name
+                 z_dist = z_dist['z_distribution']
+                 mu = z_dist[:, :args.zdim]  # first zdim columns are the posterior means
+                 z = torch.tensor(mu).float().to(device)
+                 x_recon_L, x_recon_R = model._decode(z)
+                 save_image_mat(x_recon_R, x_recon_L, args.img_path, batch_idx)
+                 logging.debug(f"Decoded and saved batch {batch_idx}")
+
+     except Exception as e:
+         logging.error(f"An error occurred: {e}")
+         raise
+
+ if __name__ == "__main__":
+     main()
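The script above drives everything through the CLI; for quick interactive checks the same model can be used directly from Python. A minimal sketch under the script's own assumptions (the checkpoint stores a 'state_dict'; inputs are one 1×192×192 grid image per hemisphere); the random tensors are placeholders for real data:

import torch
from fMRIVAE_Model import BetaVAE

zdim = 2  # must match the checkpoint (e.g. the Zdim_2 file committed above)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BetaVAE(z_dim=zdim, nc=1).to(device)
ckpt = torch.load('Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_'
                  'Batch-size_128_washu120_subsample10_train100_val10.pth.tar',
                  map_location=device)
model.load_state_dict(ckpt['state_dict'])
model.eval()

# random stand-ins for one 192x192 grid image per hemisphere
xL = torch.randn(1, 1, 192, 192, device=device)
xR = torch.randn(1, 1, 192, 192, device=device)

with torch.no_grad():
    z_dist = model._encode(xL, xR)        # (1, 2*zdim): [mu, logvar]
    mu = z_dist[:, :zdim]                 # decode from the posterior mean
    recon_L, recon_R = model._decode(mu)  # each (1, 1, 192, 192)
print(recon_L.shape, recon_R.shape)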
call_VAE_inference_example.sh ADDED
@@ -0,0 +1,21 @@
+ #!/bin/bash
+
+ # MSC subjects and their parcel counts (one batch per subject's full parcel set)
+ subjs=(01 02 03 04 05 06 07 08 09 10)
+ parcelcount=(602 567 620 616 633 580 628 710 613 649)
+ zdim=2
+ checkpoint="./VAE_Model/Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"
+
+ for i in "${!subjs[@]}"; do
+     SECONDS=0  # reset so the timing below is per subject, not cumulative
+     subj="${subjs[$i]}"
+     curr_parcel_count="${parcelcount[$i]}"
+     echo "$curr_parcel_count"
+     namestr="sub-MSC${subj}_sub-MSC${subj}Parcel"
+
+     python3 VAE_inference_example.py --data-path "./data/${namestr}" --zdim "$zdim" \
+         --resume "${checkpoint}" \
+         --z-path "./result/latent/${namestr}_Zdim${zdim}" --mode 'encode' --batch-size "$curr_parcel_count"
+
+     echo "The command took $SECONDS seconds."
+ done
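Each encoded subject ends up with save_z{batch}.mat files under its z-path, each holding a single 'z_distribution' array with mu in the first zdim columns and logvar in the rest. A minimal sketch of reading one back, assuming the directory layout the loop above produces:

import scipy.io as sio

zdim = 2
z = sio.loadmat('./result/latent/sub-MSC01_sub-MSC01Parcel_Zdim2/save_z0.mat')['z_distribution']
mu, logvar = z[:, :zdim], z[:, zdim:]  # layout written by VAE_inference_example.py
print(mu.shape)  # (batch_size, zdim): one row per parcel map in the batch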
fMRIVAE_Model.py ADDED
@@ -0,0 +1,131 @@
+ """model.py"""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.nn.init as init
+ from collections.abc import Iterable  # `collections.Iterable` was removed in Python 3.10
+
+ class BetaVAE(nn.Module):
+     """Model proposed in the original beta-VAE paper (Higgins et al., ICLR, 2017)."""
+
+     def __init__(self, z_dim=64, nc=1, cirpad_dire=(False, True)):
+         super(BetaVAE, self).__init__()
+         self.z_dim = z_dim
+         self.nc = nc
+         self.cirpad_dire = cirpad_dire
+
+         self.ocs = [64, 128, 128, 256, 256]
+         self.nLays = len(self.ocs)
+         self.topW = int(192/2**self.nLays)
+
+         # encoder
+         self.ConvL = nn.Conv2d(1, int(self.ocs[0]/2), 8, 2, 0)  # pad=3, applied in forward
+         self.ConvR = nn.Conv2d(1, int(self.ocs[0]/2), 8, 2, 0)  # pad=3, applied in forward; output after concat: (B, 64, 96, 96)
+         self.EncConvs = nn.ModuleList([nn.Conv2d(self.ocs[i-1], self.ocs[i], 4, 2, 0) for i in range(1, self.nLays)])  # pad=1, applied in forward
+         self.fc1 = nn.Linear(self.ocs[-1]*self.topW**2, z_dim*2)
+
+         # decoder
+         self.fc2 = nn.Linear(z_dim, self.ocs[-1]*self.topW**2)
+         self.DecConvs = nn.ModuleList([nn.ConvTranspose2d(self.ocs[i], self.ocs[i-1], 4, 2, 3) for i in range(4, 0, -1)])  # pad=1; dilation * (kernel_size - 1) - padding = 6 (applied in forward)
+         self.tConvL = nn.ConvTranspose2d(int(self.ocs[0]/2), nc, 8, 2, 9)  # pad=3 applied in forward; dilation * (kernel_size - 1) - padding = 4
+         self.tConvR = nn.ConvTranspose2d(int(self.ocs[0]/2), nc, 8, 2, 9)  # pad=3 applied in forward
+
+         self.relu = nn.ReLU(inplace=True)
+
+         self.weight_init()
+
+     def cirpad(self, x, padding, cirpad_dire):
+         # x is the input
+         # padding is the size of the padding
+         # cirpad_dire is (last_dim_pad, second_to_last_dim_pad)
+
+         # >>> t4d = torch.empty(3, 3, 4, 2)
+         # >>> p2d = (1, 1, 2, 2)  # pad last dim by (1, 1) and 2nd to last by (2, 2)
+         # >>> out = F.pad(t4d, p2d, "constant", 0)
+         # >>> print(out.size())
+         # torch.Size([3, 3, 8, 4])
+
+         # last dim
+         if cirpad_dire[0]:
+             x = F.pad(x, (padding, padding, 0, 0), 'circular')
+         else:
+             x = F.pad(x, (padding, padding, 0, 0), 'constant', 0)
+
+         # second-to-last dim
+         if cirpad_dire[1]:
+             x = F.pad(x, (0, 0, padding, padding), 'circular')
+         else:
+             x = F.pad(x, (0, 0, padding, padding), 'constant', 0)
+
+         return x
+
+     def weight_init(self):
+         for block in self._modules:
+             if isinstance(self._modules[block], Iterable):
+                 for m in self._modules[block]:
+                     m.apply(kaiming_init)
+             else:
+                 self._modules[block].apply(kaiming_init)
+
+     def _encode(self, xL, xR):
+         xL = self.cirpad(xL, 3, self.cirpad_dire)
+         xR = self.cirpad(xR, 3, self.cirpad_dire)
+         x = torch.cat((self.ConvL(xL), self.ConvR(xR)), 1)
+         x = self.relu(x)
+         for lay in range(self.nLays-1):
+             x = self.cirpad(x, 1, self.cirpad_dire)
+             x = self.relu(self.EncConvs[lay](x))
+         x = x.view(-1, self.ocs[-1]*self.topW*self.topW)
+         x = self.fc1(x)
+         return x
+
+     def _decode(self, z):
+         x = self.relu(self.fc2(z).view(-1, self.ocs[-1], self.topW, self.topW))
+         for lay in range(self.nLays-1):
+             x = self.cirpad(x, 1, self.cirpad_dire)
+             x = self.relu(self.DecConvs[lay](x))
+
+         xL, xR = torch.chunk(x, 2, dim=1)
+         xrL = self.tConvL(self.cirpad(xL, 3, self.cirpad_dire))
+         xrR = self.tConvR(self.cirpad(xR, 3, self.cirpad_dire))
+         return xrL, xrR
+
+     def reparametrize(self, mu, logvar):
+         std = logvar.div(2).exp()
+         eps = torch.randn_like(std)  # replaces the deprecated torch.autograd.Variable idiom
+         return mu + std*eps
+
+     def forward(self, xL, xR):
+         distributions = self._encode(xL, xR)
+         mu = distributions[:, :self.z_dim]
+         logvar = distributions[:, self.z_dim:]
+         z = self.reparametrize(mu, logvar)
+         x_recon_L, x_recon_R = self._decode(z)
+         return x_recon_L, x_recon_R, mu, logvar
+
+ def kaiming_init(m):
+     if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):  # Shall we apply init to ConvTranspose2d?
+         init.kaiming_normal_(m.weight)
+         if m.bias is not None:
+             m.bias.data.fill_(0)
+     elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
+         m.weight.data.fill_(1)
+         if m.bias is not None:
+             m.bias.data.fill_(0)
+
+ def normal_init(m, mean, std):
+     if isinstance(m, (nn.Linear, nn.Conv2d)):
+         m.weight.data.normal_(mean, std)
+         if m.bias is not None:  # check the module attribute, not .data, to avoid AttributeError when bias is None
+             m.bias.data.zero_()
+     elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
+         m.weight.data.fill_(1)
+         if m.bias is not None:
+             m.bias.data.zero_()
+
+ # if __name__ == "__main__":
+ #     m = BetaVAE()
+ #     a = torch.ones(1, 1, 192, 192)
+ #     out1, out2, _, _ = m(a, a)
+ #     print(out1.size())
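The comment block in cirpad demonstrates constant padding with F.pad; the circular branches instead wrap values around an axis, which is what the default cirpad_dire=(False, True) enables for the second-to-last dimension. A small sanity check, using only the F.pad calls that appear in the model:

import torch
import torch.nn.functional as F

x = torch.arange(6.).reshape(1, 1, 2, 3)  # (B, C, H, W); rows [[0,1,2],[3,4,5]]

# constant (zero) padding of the last dim, as used when cirpad_dire[0] is False:
print(F.pad(x, (1, 1, 0, 0), 'constant', 0)[0, 0])
# tensor([[0., 0., 1., 2., 0.],
#         [0., 3., 4., 5., 0.]])

# circular padding wraps values around instead, as used for the wrap-around dimension:
print(F.pad(x, (1, 1, 0, 0), 'circular')[0, 0])
# tensor([[2., 0., 1., 2., 0.],
#         [5., 3., 4., 5., 3.]])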
mask/Left_fMRI2Grid_192_by_192_NN.mat ADDED
Binary file (199 kB).
mask/MSE_Mask.mat ADDED
Binary file (1.21 kB).
mask/Right_fMRI2Grid_192_by_192_NN.mat ADDED
Binary file (199 kB).
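The three masks are committed as binary .mat blobs, so their variable names and layouts are not visible in this diff. A minimal, assumption-light sketch that only opens a file and lists the arrays it contains (scipy.io handles MATLAB v7 and earlier; v7.3 files are HDF5 and need h5py):

import scipy.io as sio

def inspect_mat(path):
    """List the variables stored in a .mat file (keys are unknown from this commit alone)."""
    try:
        contents = sio.loadmat(path)
        return {k: getattr(v, 'shape', None) for k, v in contents.items() if not k.startswith('__')}
    except NotImplementedError:
        # MATLAB v7.3 files are HDF5; scipy.io cannot read them
        import h5py
        with h5py.File(path, 'r') as f:
            return {k: f[k].shape for k in f.keys()}

# e.g. inspect_mat('mask/MSE_Mask.mat')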
requirements.txt ADDED
@@ -0,0 +1,73 @@
+ asttokens==2.4.1
+ backcall==0.2.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ comm==0.2.0
+ contourpy==1.1.1
+ cycler==0.12.1
+ debugpy==1.8.0
+ decorator==5.1.1
+ executing==2.0.1
+ filelock==3.13.1
+ fonttools==4.47.0
+ fsspec==2023.12.2
+ h5py==3.10.0
+ idna==3.6
+ importlib-metadata==7.0.0
+ importlib-resources==6.1.1
+ ipykernel==6.27.1
+ ipython==8.12.3
+ jedi==0.19.1
+ Jinja2==3.1.2
+ jupyter_client==8.6.0
+ jupyter_core==5.5.1
+ kiwisolver==1.4.5
+ MarkupSafe==2.1.3
+ matplotlib==3.7.4
+ matplotlib-inline==0.1.6
+ mpmath==1.3.0
+ nest-asyncio==1.5.8
+ networkx==3.1
+ nibabel==5.2.0
+ numpy==1.24.4
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ packaging==23.2
+ parso==0.8.3
+ pexpect==4.9.0
+ pickleshare==0.7.5
+ Pillow==10.1.0
+ platformdirs==4.1.0
+ prompt-toolkit==3.0.43
+ psutil==5.9.7
+ ptyprocess==0.7.0
+ pure-eval==0.2.2
+ Pygments==2.17.2
+ pyparsing==3.1.1
+ python-dateutil==2.8.2
+ pyzmq==25.1.2
+ requests==2.31.0
+ scipy==1.10.1
+ six==1.16.0
+ stack-data==0.6.3
+ sympy==1.12
+ torch==2.1.2+cu118
+ torchaudio==2.1.2+cu118
+ torchvision==0.16.2+cu118
+ tornado==6.4
+ traitlets==5.14.0
+ triton==2.1.0
+ typing_extensions==4.9.0
+ urllib3==2.1.0
+ wcwidth==0.2.12
+ zipp==3.17.0
utils.py ADDED
@@ -0,0 +1,78 @@
+ import os
+ import h5py
+ import torch
+ import torch.utils.data as data
+ import torch.multiprocessing
+ import scipy.io as sio
+ from torch.nn import functional as F
+ # torch.multiprocessing.set_start_method('spawn')
+
+ class H5Dataset(data.Dataset):
+     def __init__(self, H5Path):
+         super(H5Dataset, self).__init__()
+         self.H5File = h5py.File(H5Path, 'r')
+         self.LeftData = self.H5File['LeftData']
+         self.RightData = self.H5File['RightData']
+         # self.LeftMask = self.H5File['LeftMask'][:]  # update 2024.01.11: masks loaded separately
+         # self.RightMask = self.H5File['RightMask'][:]
+
+     def __getitem__(self, index):
+         return (torch.from_numpy(self.LeftData[index, :, :, :]).float(),
+                 torch.from_numpy(self.RightData[index, :, :, :]).float())
+
+     def __len__(self):
+         return self.LeftData.shape[0]
+
+ def save_image_mat(img_r, img_l, result_path, idx):
+     save_data = {}
+     save_data['recon_L'] = img_l.detach().cpu().numpy()
+     save_data['recon_R'] = img_r.detach().cpu().numpy()
+     sio.savemat(os.path.join(result_path, 'img{}.mat'.format(idx)), save_data)  # os.path.join avoids gluing the filename onto the directory name
+
+ def load_dataset(data_path, batch_size):
+     kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
+     train_dir = data_path + '_train.h5'
+     val_dir = data_path + '_val.h5'
+     train_set = H5Dataset(train_dir)
+     val_set = H5Dataset(val_dir)
+     train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=False, **kwargs)
+     val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, **kwargs)
+     return train_loader, val_loader
+
+ def load_dataset_test(data_path, batch_size):
+     kwargs = {'num_workers': 1, 'pin_memory': True} if torch.cuda.is_available() else {}
+     test_dir = data_path + '.h5'
+     test_set = H5Dataset(test_dir)
+     test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, **kwargs)
+     return test_loader
+
+ # loss function # update 20240109: mask out zeros
+ def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
+
+     Image_Size = xL.size(3)
+     beta /= Image_Size**2
+     # print('====> Image_Size: {} Beta: {:.8f}'.format(Image_Size, beta))
+     # R_batch_size = xR.size(0)
+
+     # Tutorial on VAE, page 14:
+     # log[P(X|z)] = C - \frac{1}{2} ||X - f(z)||^2 / \sigma^2
+     #             = C - \frac{1}{2} \sum_{i=1}^{N} ||X^{(i)} - f(z^{(i)})||^2 / \sigma^2
+     #             = C - \frac{1}{2} N \cdot F.mse_loss(Xhat, Xtrue) / \sigma^2
+     # log[P(X|z)] - C = -\frac{1}{2} \cdot 2 \cdot 192 \cdot 192 / \sigma^2 \cdot F.mse_loss
+     # Therefore, vae_beta = \frac{1}{36864 / \sigma^2}
+
+     # mask out zeros
+     valid_mask_L = xL != 0
+     valid_mask_R = xR != 0
+
+     if left_mask is not None:
+         valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
+         valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)
+
+     MSE_L = F.mse_loss(x_recon_L*valid_mask_L, xL*valid_mask_L, reduction='mean')  # `size_average` is deprecated
+     MSE_R = F.mse_loss(x_recon_R*valid_mask_R, xR*valid_mask_R, reduction='mean')
+
+     # KLD is averaged across batch samples
+     KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()
+
+     return KLD * beta + MSE_L + MSE_R
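Finally, a minimal sketch of calling loss_function on placeholder tensors, matching the signature above. Passing left_mask=None skips the explicit-mask branch, and zero mu/logvar make the KL term vanish, leaving only the two masked MSE terms:

import torch
from utils import loss_function

B, zdim = 8, 2
xL = torch.randn(B, 1, 192, 192)
xR = torch.randn(B, 1, 192, 192)
x_recon_L = torch.randn(B, 1, 192, 192)
x_recon_R = torch.randn(B, 1, 192, 192)
mu = torch.zeros(B, zdim)
logvar = torch.zeros(B, zdim)  # unit variance, so the KLD term is 0 here

# beta=20.0 matches the Vae-beta_20.0 checkpoints; it is divided by 192**2 inside
loss = loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar,
                     beta=20.0, left_mask=None, right_mask=None)
print(loss.item())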