Basic files for implementing trained model
Browse files
- Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar +3 -0
- Checkpoint/checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar +3 -0
- Checkpoint/checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar +3 -0
- VAE_inference_example.py +83 -0
- call_VAE_inference_example.sh +20 -0
- fMRIVAE_Model.py +140 -0
- mask/Left_fMRI2Grid_192_by_192_NN.mat +0 -0
- mask/MSE_Mask.mat +0 -0
- mask/Right_fMRI2Grid_192_by_192_NN.mat +0 -0
- requirements.txt +73 -0
- utils.py +80 -0
Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:81de496d81529e84942d3ca6652e578ad412712044e472bcd01635e0533609f3
|
3 |
+
size 48127946
|
Checkpoint/checkpoint49_2024-06-21_Zdim_4_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93e6aac3dc3eeb3296ce2e24a480e5b704b85beb32360a31953409fb7dfc00ac
|
3 |
+
size 48791690
|
Checkpoint/checkpoint49_2024-11-28_Zdim_3_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f48320b2e5a4ae6d668e22284b50470d8444dd99cd82427ad0ccead34fdd9bbd
|
3 |
+
size 48459722
|
VAE_inference_example.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch # tested on version 2.1.2+cu118
|
2 |
+
import scipy.io as io
|
3 |
+
import argparse
|
4 |
+
import logging
|
5 |
+
from utils import load_dataset_test, save_image_mat
|
6 |
+
from fMRIVAE_Model import BetaVAE
|
7 |
+
import os
|
8 |
+
|
9 |
+
def main():
    """Run VAE inference from the command line: encode, decode, or both.

    Encoding reads ``<--data-path>.h5`` through ``load_dataset_test`` and writes
    one ``save_z<batch index>.mat`` file per batch into ``--z-path`` (key
    ``z_distribution``, the raw output of ``BetaVAE._encode``).

    Decoding reads every ``save_z<index>.mat`` file found in ``--z-path``, keeps
    the first ``--zdim`` columns of ``z_distribution`` (named ``mu`` below),
    passes them through ``BetaVAE._decode`` and stores the result with
    ``save_image_mat`` in ``--img-path`` under the same index.

    Raises:
        RuntimeError: if the ``--resume`` checkpoint file does not exist.
    """
    parser = argparse.ArgumentParser(description='VAE for fMRI generation')
    # NOTE(review): --batch-size has no default, so None is forwarded to the
    # DataLoader unchanged when the flag is omitted.
    parser.add_argument('--batch-size', type=int, metavar='N', help='how many samples per saved file?')
    parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
    parser.add_argument('--zdim', type=int, default=256, metavar='N', help='dimension of latent variables')
    parser.add_argument('--data-path', type=str, metavar='DIR', help='path to dataset')
    parser.add_argument('--z-path', type=str, default='./result/latent/', help='path to saved z files')
    parser.add_argument('--resume', type=str, default='./checkpoint/checkpoint.pth.tar', help='the VAE checkpoint')
    parser.add_argument('--img-path', type=str, default='./result/recon', help='path to save reconstructed images')
    # Restricting the choices turns a mistyped mode into a usage error instead
    # of a run that loads the model and then silently does nothing.
    parser.add_argument('--mode', type=str, default='both', choices=('encode', 'decode', 'both'),
                        help='choose from \'encode\',\'decode\' or \'both\'')
    parser.add_argument('--debug', action='store_true', help='Enable debug mode for detailed logging')

    args = parser.parse_args()

    encode_requested = args.mode in ('encode', 'both')
    decode_requested = args.mode in ('decode', 'both')

    if encode_requested and args.data_path is None:
        # Without this check the failure only surfaces later as a TypeError
        # when '.h5' is appended to None.
        parser.error('--data-path is required when --mode is \'encode\' or \'both\'')

    # os.makedirs replaces the previous shell call: no shell is spawned, paths
    # containing spaces or shell metacharacters are safe, and it is portable.
    os.makedirs(args.z_path, exist_ok=True)
    if args.mode != 'encode':
        os.makedirs(args.img_path, exist_ok=True)

    # Set logging level based on debug flag
    logging_level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(level=logging_level, format='%(asctime)s - %(levelname)s - %(message)s')

    logging.debug("Starting the VAE inference script.")
    logging.debug(f"Parsed arguments: {args}")

    try:
        torch.manual_seed(args.seed)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.debug(f"Using device: {device}")

        logging.debug(f"Loading VAE model from {args.resume}.")
        model = BetaVAE(z_dim=args.zdim, nc=1).to(device)
        if not os.path.isfile(args.resume):
            logging.error(f"Checkpoint not found at {args.resume}")
            raise RuntimeError("Checkpoint not found.")
        checkpoint = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        logging.debug("Checkpoint loaded.")

        if encode_requested:
            logging.debug("Starting encoding process...")
            test_loader = load_dataset_test(args.data_path, args.batch_size)
            logging.debug(f"Loaded test dataset from {args.data_path}")
            # Inference only: skip building the autograd graph.
            with torch.no_grad():
                for batch_idx, (xL, xR) in enumerate(test_loader):
                    xL = xL.to(device)
                    xR = xR.to(device)
                    z_distribution = model._encode(xL, xR)
                    save_data = {'z_distribution': z_distribution.detach().cpu().numpy()}
                    io.savemat(os.path.join(args.z_path, f'save_z{batch_idx}.mat'), save_data)
                    logging.debug(f"Encoded batch {batch_idx}")

        if decode_requested:
            logging.debug("Starting decoding process...")
            # Collect (index, filename) pairs for files written by the encode
            # step and process them in numeric order. Each file is loaded by
            # its own name rather than by its position in os.listdir(), so a
            # gap in the numbering or an unrelated 'save_*' file no longer
            # makes the loop load the wrong (or a missing) file.
            prefix, suffix = 'save_z', '.mat'
            latent_files = []
            for name in os.listdir(args.z_path):
                if not (name.startswith(prefix) and name.endswith(suffix)):
                    continue
                stem = name[len(prefix):-len(suffix)]
                if stem.isdecimal():
                    latent_files.append((int(stem), name))
            latent_files.sort()
            filelist = [name for _, name in latent_files]
            logging.debug(f"Filelist: {filelist}")
            with torch.no_grad():
                for batch_idx, filename in latent_files:
                    logging.debug(f"Decoding file {filename}")
                    z_dist = io.loadmat(os.path.join(args.z_path, filename))
                    z_dist = z_dist['z_distribution']
                    mu = z_dist[:, :args.zdim]
                    # Force float32 so latent files stored in double precision
                    # still match the dtype of the model weights.
                    z = torch.tensor(mu, dtype=torch.float32, device=device)
                    x_recon_L, x_recon_R = model._decode(z)
                    save_image_mat(x_recon_R, x_recon_L, args.img_path, batch_idx)
                    logging.debug(f"Decoded and saved batch {batch_idx}")

    except Exception as e:
        # Top-level boundary: log the failure, then let it propagate.
        logging.error(f"An error occurred: {e}")
        raise


if __name__ == "__main__":
    main()
|
call_VAE_inference_example.sh
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
# Encode the data of every MSC subject with the Zdim=2 checkpoint by calling
# VAE_inference_example.py once per subject in 'encode' mode.

# MSC
subjs=(01 02 03 04 05 06 07 08 09 10)
# Value passed as --batch-size for each subject; index-aligned with subjs.
parcelcount=(602 567 620 616 633 580 628 710 613 649)
zdim=2
# NOTE(review): the checkpoint is addressed under ./VAE_Model/ while the Python
# script is invoked from the current directory - confirm the intended working
# directory.
checkpoint="./VAE_Model/Checkpoint/checkpoint49_2024-03-28_Zdim_2_Vae-beta_20.0_Lr_0.0001_Batch-size_128_washu120_subsample10_train100_val10.pth.tar"
for i in "${!subjs[@]}"; do
    # Reset bash's built-in timer for every subject so the message at the end
    # of the iteration reports this command only, not the cumulative run time.
    SECONDS=0
    subj="${subjs[$i]}"
    curr_parcel_count="${parcelcount[$i]}"
    echo "$curr_parcel_count"
    namestr="sub-MSC${subj}_sub-MSC${subj}Parcel"

    # All expansions are quoted so paths are never word-split or globbed.
    python3 VAE_inference_example.py --data-path "./data/${namestr}" --zdim "$zdim" \
        --resume "${checkpoint}" \
        --z-path "./result/latent/${namestr}_Zdim${zdim}" --mode 'encode' --batch-size "$curr_parcel_count"

    echo "The command took $SECONDS seconds."
done
|
fMRIVAE_Model.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""fMRIVAE_Model.py: beta-VAE model with separate left (L) and right (R) image branches."""
|
2 |
+
|
3 |
+
try:
    # Iterable lives in collections.abc; the alias in collections was removed
    # in Python 3.10, where the old import raises ImportError.
    from collections.abc import Iterable
except ImportError:
    from collections import Iterable

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.autograd import Variable
|
9 |
+
|
10 |
+
class BetaVAE(nn.Module):
    """Beta-VAE with twin left/right convolutional branches.

    Based on the model proposed in the original beta-VAE paper
    (Higgins et al, ICLR, 2017).

    The encoder runs one strided convolution per input image (``ConvL`` /
    ``ConvR``), concatenates the two feature maps along the channel axis and
    applies the shared ``EncConvs`` stack followed by ``fc1``, which outputs
    ``2 * z_dim`` values per sample. The decoder mirrors this: ``fc2``, the
    shared ``DecConvs`` stack, a channel split, then ``tConvL`` / ``tConvR``.

    The fully connected layers are sized for 192 x 192 inputs (``topW`` is
    derived from 192 and the number of stride-2 layers).

    Args:
        z_dim: number of latent dimensions.
        nc: number of channels of each input and each reconstructed image.
        cirpad_dire: ``(last_dim, second_to_last_dim)`` flags; a truthy flag
            selects circular padding for that axis, a falsy one zero padding.
    """

    def __init__(self, z_dim=64, nc=1, cirpad_dire=(False, True)):
        super(BetaVAE, self).__init__()
        self.z_dim = z_dim
        self.nc = nc
        self.cirpad_dire = cirpad_dire

        # Output channels of the successive encoder stages.
        self.ocs = [64, 128, 128, 256, 256]
        self.nLays = len(self.ocs)
        # Every stage halves the spatial size: 192 / 2**5 = 6.
        self.topW = 192 // 2 ** self.nLays

        branch_channels = self.ocs[0] // 2

        # encoder
        # All convolutions use padding=0 because padding is applied manually
        # with cirpad() in _encode (3 for the first stage, 1 afterwards).
        # The input channel count follows nc (it used to be hard-coded to 1,
        # which made any nc != 1 fail in the encoder).
        self.ConvL = nn.Conv2d(nc, branch_channels, 8, 2, 0)
        self.ConvR = nn.Conv2d(nc, branch_channels, 8, 2, 0)
        self.EncConvs = nn.ModuleList(
            [nn.Conv2d(self.ocs[i - 1], self.ocs[i], 4, 2, 0) for i in range(1, self.nLays)]
        )
        self.fc1 = nn.Linear(self.ocs[-1] * self.topW ** 2, z_dim * 2)

        # decoder
        self.fc2 = nn.Linear(z_dim, self.ocs[-1] * self.topW ** 2)
        # With the cirpad(1) applied in _decode, kernel 4 / stride 2 / padding 3
        # gives (W + 2 - 1) * 2 - 2 * 3 + 4 = 2 * W, i.e. each layer doubles W.
        self.DecConvs = nn.ModuleList(
            [nn.ConvTranspose2d(self.ocs[i], self.ocs[i - 1], 4, 2, 3) for i in range(self.nLays - 1, 0, -1)]
        )
        # With the cirpad(3) applied in _decode, kernel 8 / stride 2 / padding 9
        # gives (W + 6 - 1) * 2 - 2 * 9 + 8 = 2 * W.
        self.tConvL = nn.ConvTranspose2d(branch_channels, nc, 8, 2, 9)
        self.tConvR = nn.ConvTranspose2d(branch_channels, nc, 8, 2, 9)

        self.relu = nn.ReLU(inplace=True)

        self.weight_init()

    def cirpad(self, x, padding, cirpad_dire):
        """Pad the last two dimensions of ``x`` by ``padding`` on both sides.

        Args:
            x: input tensor.
            padding: number of elements added on each side of an axis.
            cirpad_dire: ``(last_dim, second_to_last_dim)`` flags; truthy means
                circular padding for that axis, falsy means zero padding.

        Returns:
            The padded tensor.
        """
        # F.pad takes the pad widths starting from the last dimension:
        # (last_left, last_right, second_last_left, second_last_right).

        # last dim
        if cirpad_dire[0]:
            x = F.pad(x, (padding, padding, 0, 0), 'circular')
        else:
            x = F.pad(x, (padding, padding, 0, 0), "constant", 0)

        # second last dim
        if cirpad_dire[1]:
            x = F.pad(x, (0, 0, padding, padding), 'circular')
        else:
            x = F.pad(x, (0, 0, padding, padding), "constant", 0)

        return x

    def weight_init(self):
        """Apply ``kaiming_init`` to every submodule.

        ``nn.Module.apply`` walks the children in registration order (including
        the entries of each ``ModuleList``), which is the order the previous
        hand-written loop used, so the sequence of random draws is unchanged.
        """
        self.apply(kaiming_init)

    def _encode(self, xL, xR):
        """Encode a left/right image pair into ``2 * z_dim`` values per sample.

        ``forward`` interprets the first ``z_dim`` columns of the result as
        ``mu`` and the remaining ones as ``logvar``.
        """
        xL = self.cirpad(xL, 3, self.cirpad_dire)
        xR = self.cirpad(xR, 3, self.cirpad_dire)
        # Merge both branches along the channel axis.
        x = torch.cat((self.ConvL(xL), self.ConvR(xR)), 1)
        x = self.relu(x)
        for conv in self.EncConvs:
            x = self.cirpad(x, 1, self.cirpad_dire)
            x = self.relu(conv(x))
        x = x.view(-1, self.ocs[-1] * self.topW * self.topW)
        x = self.fc1(x)
        return x

    def _decode(self, z):
        """Decode latent vectors ``z`` into a ``(left, right)`` pair of images."""
        x = self.relu(self.fc2(z).view(-1, self.ocs[-1], self.topW, self.topW))
        for deconv in self.DecConvs:
            x = self.cirpad(x, 1, self.cirpad_dire)
            x = self.relu(deconv(x))

        # Split the shared feature map into the two output branches.
        xL, xR = torch.chunk(x, 2, dim=1)

        xrL = self.tConvL(self.cirpad(xL, 3, self.cirpad_dire))
        xrR = self.tConvR(self.cirpad(xR, 3, self.cirpad_dire))
        return xrL, xrR

    def reparametrize(self, mu, logvar):
        """Sample ``mu + std * eps`` with ``std = exp(logvar / 2)`` and ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like replaces the deprecated Variable(std.data.new(...).normal_()).
        eps = torch.randn_like(std)
        return mu + std * eps

    def forward(self, xL, xR):
        """Encode, sample a latent vector and decode.

        Returns:
            ``(x_recon_L, x_recon_R, mu, logvar)``.
        """
        distributions = self._encode(xL, xR)
        mu = distributions[:, :self.z_dim]
        logvar = distributions[:, self.z_dim:]
        z = self.reparametrize(mu, logvar)
        x_recon_L, x_recon_R = self._decode(z)
        return x_recon_L, x_recon_R, mu, logvar


def kaiming_init(m):
    """Initialise one module in place.

    Linear, Conv2d and ConvTranspose2d layers get Kaiming-normal weights and a
    zero bias; BatchNorm1d/2d layers get unit weight and zero bias. Any other
    module is left untouched.
    """
    # NOTE(review): the original author questioned whether ConvTranspose2d
    # should be initialised this way; it is kept as before.
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
        init.kaiming_normal_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # weight and bias are both None when the layer was built with affine=False.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
|
121 |
+
|
122 |
+
def normal_init(m, mean, std):
    """Initialise one module in place with normally distributed weights.

    Linear and Conv2d layers get weights drawn from N(mean, std**2) and a zero
    bias; BatchNorm1d/2d layers get unit weight and zero bias. Any other module
    is left untouched.

    Args:
        m: the module to initialise.
        mean: mean of the normal distribution used for the weights.
        std: standard deviation of that distribution.
    """
    if isinstance(m, (nn.Linear, nn.Conv2d)):
        init.normal_(m.weight, mean, std)
        # Test the parameter itself: with bias=False, m.bias is None and the
        # previous `m.bias.data` lookup raised AttributeError.
        if m.bias is not None:
            init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # weight and bias are both None when the layer was built with affine=False.
        if m.weight is not None:
            init.ones_(m.weight)
        if m.bias is not None:
            init.zeros_(m.bias)
|
131 |
+
|
132 |
+
#if __name__ == "__main__":
|
133 |
+
# m = BetaVAE()
|
134 |
+
# a=torch.ones(1,1,192,192)
|
135 |
+
# out1, out2, _, _ = m(a,a)
|
136 |
+
# print(out1.size())
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
|
mask/Left_fMRI2Grid_192_by_192_NN.mat
ADDED
Binary file (199 kB). View file
|
|
mask/MSE_Mask.mat
ADDED
Binary file (1.21 kB). View file
|
|
mask/Right_fMRI2Grid_192_by_192_NN.mat
ADDED
Binary file (199 kB). View file
|
|
requirements.txt
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
asttokens==2.4.1
|
2 |
+
backcall==0.2.0
|
3 |
+
certifi==2023.11.17
|
4 |
+
charset-normalizer==3.3.2
|
5 |
+
comm==0.2.0
|
6 |
+
contourpy==1.1.1
|
7 |
+
cycler==0.12.1
|
8 |
+
debugpy==1.8.0
|
9 |
+
decorator==5.1.1
|
10 |
+
executing==2.0.1
|
11 |
+
filelock==3.13.1
|
12 |
+
fonttools==4.47.0
|
13 |
+
fsspec==2023.12.2
|
14 |
+
h5py==3.10.0
|
15 |
+
idna==3.6
|
16 |
+
importlib-metadata==7.0.0
|
17 |
+
importlib-resources==6.1.1
|
18 |
+
ipykernel==6.27.1
|
19 |
+
ipython==8.12.3
|
20 |
+
jedi==0.19.1
|
21 |
+
Jinja2==3.1.2
|
22 |
+
jupyter_client==8.6.0
|
23 |
+
jupyter_core==5.5.1
|
24 |
+
kiwisolver==1.4.5
|
25 |
+
MarkupSafe==2.1.3
|
26 |
+
matplotlib==3.7.4
|
27 |
+
matplotlib-inline==0.1.6
|
28 |
+
mpmath==1.3.0
|
29 |
+
nest-asyncio==1.5.8
|
30 |
+
networkx==3.1
|
31 |
+
nibabel==5.2.0
|
32 |
+
numpy==1.24.4
|
33 |
+
nvidia-cublas-cu12==12.1.3.1
|
34 |
+
nvidia-cuda-cupti-cu12==12.1.105
|
35 |
+
nvidia-cuda-nvrtc-cu12==12.1.105
|
36 |
+
nvidia-cuda-runtime-cu12==12.1.105
|
37 |
+
nvidia-cudnn-cu12==8.9.2.26
|
38 |
+
nvidia-cufft-cu12==11.0.2.54
|
39 |
+
nvidia-curand-cu12==10.3.2.106
|
40 |
+
nvidia-cusolver-cu12==11.4.5.107
|
41 |
+
nvidia-cusparse-cu12==12.1.0.106
|
42 |
+
nvidia-nccl-cu12==2.18.1
|
43 |
+
nvidia-nvjitlink-cu12==12.3.101
|
44 |
+
nvidia-nvtx-cu12==12.1.105
|
45 |
+
packaging==23.2
|
46 |
+
parso==0.8.3
|
47 |
+
pexpect==4.9.0
|
48 |
+
pickleshare==0.7.5
|
49 |
+
Pillow==10.1.0
|
50 |
+
platformdirs==4.1.0
|
51 |
+
prompt-toolkit==3.0.43
|
52 |
+
psutil==5.9.7
|
53 |
+
ptyprocess==0.7.0
|
54 |
+
pure-eval==0.2.2
|
55 |
+
Pygments==2.17.2
|
56 |
+
pyparsing==3.1.1
|
57 |
+
python-dateutil==2.8.2
|
58 |
+
pyzmq==25.1.2
|
59 |
+
requests==2.31.0
|
60 |
+
scipy==1.10.1
|
61 |
+
six==1.16.0
|
62 |
+
stack-data==0.6.3
|
63 |
+
sympy==1.12
|
64 |
+
torch==2.1.2+cu118
|
65 |
+
torchaudio==2.1.2+cu118
|
66 |
+
torchvision==0.16.2+cu118
|
67 |
+
tornado==6.4
|
68 |
+
traitlets==5.14.0
|
69 |
+
triton==2.1.0
|
70 |
+
typing_extensions==4.9.0
|
71 |
+
urllib3==2.1.0
|
72 |
+
wcwidth==0.2.12
|
73 |
+
zipp==3.17.0
|
utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import h5py
|
2 |
+
import torch
|
3 |
+
import torch.utils.data as data
|
4 |
+
import torch.multiprocessing
|
5 |
+
import scipy.io as sio
|
6 |
+
from torch.nn import functional as F
|
7 |
+
# torch.multiprocessing.set_start_method('spawn')
|
8 |
+
|
9 |
+
class H5Dataset(data.Dataset):
    """Dataset that serves paired samples from one HDF5 file.

    The file must contain the datasets ``LeftData`` and ``RightData``; both are
    indexed as ``[index, :, :, :]``, so they are expected to be 4-dimensional
    with the sample index first. The number of samples is taken from
    ``LeftData``.

    The file is opened read-only in the constructor and stays open until
    ``close()`` is called; the dataset can also be used as a context manager.

    Args:
        H5Path: path of the HDF5 file to open.
    """

    def __init__(self, H5Path):
        super(H5Dataset, self).__init__()
        self.H5File = h5py.File(H5Path, 'r')
        self.LeftData = self.H5File['LeftData']
        self.RightData = self.H5File['RightData']
        # Masks are no longer read from this file (update 2024.01.11: masks
        # are loaded separately).

    def __getitem__(self, index):
        """Return ``(left, right)`` for sample ``index`` as float32 tensors."""
        left = torch.from_numpy(self.LeftData[index, :, :, :]).float()
        right = torch.from_numpy(self.RightData[index, :, :, :]).float()
        return left, right

    def __len__(self):
        """Return the number of samples (first dimension of ``LeftData``)."""
        return self.LeftData.shape[0]

    def close(self):
        """Close the underlying file; calling it more than once is harmless."""
        h5_file = getattr(self, 'H5File', None)
        if h5_file is not None:
            h5_file.close()
        # Drop the handles so that later use fails clearly instead of touching
        # a closed file.
        self.H5File = None
        self.LeftData = None
        self.RightData = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False
|
24 |
+
|
25 |
+
def save_image_mat(img_r, img_l, result_path, idx):
    """Save a pair of reconstructed images to ``<result_path>/img<idx>.mat``.

    The file holds two variables: ``recon_L`` (from ``img_l``) and ``recon_R``
    (from ``img_r``).

    Args:
        img_r: tensor stored under ``recon_R``.
        img_l: tensor stored under ``recon_L``.
        result_path: directory that receives the file; a trailing path
            separator is optional.
        idx: index inserted into the file name.
    """
    # Local import: this module does not import os at file level.
    import os

    save_data = {
        'recon_L': img_l.detach().cpu().numpy(),
        'recon_R': img_r.detach().cpu().numpy(),
    }
    # os.path.join instead of string concatenation: without it a directory
    # given without a trailing slash (e.g. './result/recon') produced
    # './result/reconimg0.mat' next to the directory instead of inside it.
    sio.savemat(os.path.join(result_path, 'img{}.mat'.format(idx)), save_data)
|
30 |
+
|
31 |
+
def load_dataset(data_path, batch_size):
    """Create the training and validation loaders for one dataset prefix.

    The data is read from ``<data_path>_train.h5`` and ``<data_path>_val.h5``.
    Neither loader shuffles. When CUDA is available each loader uses one worker
    process and pinned memory; otherwise the DataLoader defaults apply.

    Args:
        data_path: common path prefix of the two HDF5 files.
        batch_size: batch size handed to both loaders.

    Returns:
        ``(train_loader, val_loader)``.
    """
    if torch.cuda.is_available():
        loader_options = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_options = {}

    # Open both files first, then wrap them, as the loaders share all settings.
    datasets = [H5Dataset(data_path + suffix) for suffix in ('_train.h5', '_val.h5')]
    train_loader, val_loader = [
        torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, **loader_options)
        for dataset in datasets
    ]
    return train_loader, val_loader
|
40 |
+
|
41 |
+
def load_dataset_test(data_path, batch_size):
    """Create the loader for the test data stored in ``<data_path>.h5``.

    The loader does not shuffle. When CUDA is available it uses one worker
    process and pinned memory; otherwise the DataLoader defaults apply.

    Args:
        data_path: path of the HDF5 file without its ``.h5`` extension.
        batch_size: batch size handed to the loader.

    Returns:
        The test ``DataLoader``.
    """
    if torch.cuda.is_available():
        loader_options = {'num_workers': 1, 'pin_memory': True}
    else:
        loader_options = {}

    return torch.utils.data.DataLoader(
        H5Dataset(data_path + '.h5'),
        batch_size=batch_size,
        shuffle=False,
        **loader_options
    )
|
47 |
+
|
48 |
+
# loss function # update 20240109 mask out zeros
|
49 |
+
def loss_function(xL, xR, x_recon_L, x_recon_R, mu, logvar, beta, left_mask, right_mask):
    """Compute the beta-VAE loss: masked reconstruction MSE plus a weighted KL term.

    Positions where the target image is exactly zero are excluded from the
    reconstruction error, as are positions where the optional mask differs
    from 1. Excluded positions are zeroed in both prediction and target before
    ``F.mse_loss`` is applied, so the mean is still taken over all elements.

    Args:
        xL, xR: target left/right images; ``xL.size(3)`` is used as image size.
        x_recon_L, x_recon_R: reconstructed left/right images.
        mu, logvar: mean and log-variance of the approximate posterior.
        beta: KL weight; it is divided by ``Image_Size**2`` before use. The
            caller's value is never modified.
        left_mask, right_mask: optional masks (``None`` to skip); each one is
            applied independently of the other.

    Returns:
        Scalar tensor ``KLD * beta / Image_Size**2 + MSE_L + MSE_R``.
    """
    Image_Size = xL.size(3)

    # Out-of-place division: `beta /= ...` would modify a tensor-valued beta
    # owned by the caller on every call.
    beta = beta / Image_Size ** 2

    # Tutorial on VAE Page-14
    # log[P(X|z)] = C - \frac{1}{2} ||X-f(z)||^2 // \sigma^2
    #             = C - \frac{1}{2} \sum_{i=1}^{N} ||X^{(i)}-f(z^{(i)})||^2 // \sigma^2
    #             = C - \frac{1}{2} N * F.mse_loss(Xhat-Xtrue) // \sigma^2
    # log[P(X|z)]-C = - \frac{1}{2}*2*192*192//\sigma^2 * F.mse_loss
    # Therefore, vae_beta = \frac{1}{36864//\sigma^2}

    # mask out zeros
    valid_mask_L = xL != 0
    valid_mask_R = xR != 0

    # Each mask is checked on its own; previously right_mask was dereferenced
    # whenever left_mask was given, even if right_mask itself was None.
    if left_mask is not None:
        valid_mask_L = valid_mask_L & (left_mask.detach().to(torch.int32) == 1)
    if right_mask is not None:
        valid_mask_R = valid_mask_R & (right_mask.detach().to(torch.int32) == 1)

    # reduction='mean' is the supported spelling of the deprecated size_average=True.
    MSE_L = F.mse_loss(x_recon_L * valid_mask_L, xL * valid_mask_L, reduction='mean')
    MSE_R = F.mse_loss(x_recon_R * valid_mask_R, xR * valid_mask_R, reduction='mean')

    # KLD is summed over latent dimensions and averaged across batch samples
    KLD = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(1).mean()

    return KLD * beta + MSE_L + MSE_R
|
80 |
+
|