soumickmj commited on
Commit
301db00
·
verified ·
1 Parent(s): f368560

Upload UNetFTSR

Browse files
Files changed (5) hide show
  1. UNetFTSR.py +19 -0
  2. UNetFTSRConfig.py +25 -0
  3. config.json +20 -0
  4. model.safetensors +3 -0
  5. unet3D.py +137 -0
UNetFTSR.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .unet3D import UNet
3
+ from .UNetFTSRConfig import UNetFTSRConfig
4
+
5
+ class UNetFTSR(PreTrainedModel):
6
+ config_class = UNetFTSRConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = UNet(
10
+ in_channels=config.in_channels,
11
+ n_classes=config.n_classes,
12
+ depth=config.depth,
13
+ wf=config.wf,
14
+ padding=config.padding,
15
+ batch_norm=config.batch_norm,
16
+ up_mode=config.up_mode,
17
+ dropout=config.dropout)
18
+ def forward(self, x):
19
+ return self.model(x)
UNetFTSRConfig.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class UNetFTSRConfig(PretrainedConfig):
5
+ model_type = "UNet"
6
+ def __init__(
7
+ self,
8
+ in_channels=1,
9
+ n_classes=1,
10
+ depth=3,
11
+ wf=6,
12
+ padding=True,
13
+ batch_norm=False,
14
+ up_mode='upconv',
15
+ dropout=False,
16
+ **kwargs):
17
+ self.in_channels = in_channels
18
+ self.n_classes = n_classes
19
+ self.depth = depth
20
+ self.wf = wf
21
+ self.padding = padding
22
+ self.batch_norm = batch_norm
23
+ self.up_mode = up_mode
24
+ self.dropout = dropout
25
+ super().__init__(**kwargs)
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UNetFTSR"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "UNetFTSRConfig.UNetFTSRConfig",
7
+ "AutoModel": "UNetFTSR.UNetFTSR"
8
+ },
9
+ "batch_norm": false,
10
+ "depth": 3,
11
+ "dropout": false,
12
+ "in_channels": 1,
13
+ "model_type": "UNet",
14
+ "n_classes": 1,
15
+ "padding": true,
16
+ "torch_dtype": "float32",
17
+ "transformers_version": "4.44.2",
18
+ "up_mode": "upconv",
19
+ "wf": 6
20
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6135269f442ff8de0dc58f670de1e65ac5fb98de769d20a03a6e64fcf7bbd6e
3
+ size 21675396
unet3D.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This model is part of the paper "Fine-tuning deep learning model parameters for improved super-resolution of dynamic MRI with prior-knowledge" (https://doi.org/10.1016/j.artmed.2021.102196)
2
+ # and has been published on GitHub: https://github.com/soumickmj/FTSuperResDynMRI/blob/main/models/unet3D.py
3
+
4
+
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ __author__ = "Soumick Chatterjee, Chompunuch Sarasaen"
10
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
11
+ __credits__ = ["Soumick Chatterjee", "Chompunuch Sarasaen"]
12
+ __license__ = "GPL"
13
+ __version__ = "1.0.0"
14
+ __maintainer__ = "Soumick Chatterjee"
15
+ __email__ = "[email protected]"
16
+ __status__ = "Published"
17
+
18
+
19
+ class UNet(nn.Module):
20
+ """
21
+ Implementation of
22
+ U-Net: Convolutional Networks for Biomedical Image Segmentation
23
+ (Ronneberger et al., 2015)
24
+ https://arxiv.org/abs/1505.04597
25
+
26
+ Using the default arguments will yield the exact version used
27
+ in the original paper
28
+
29
+ Adapted from https://discuss.pytorch.org/t/unet-implementation/426
30
+
31
+ Args:
32
+ in_channels (int): number of input channels
33
+ n_classes (int): number of output channels
34
+ depth (int): depth of the network
35
+ wf (int): number of filters in the first layer is 2**wf
36
+ padding (bool): if True, apply padding such that the input shape
37
+ is the same as the output.
38
+ This may introduce artifacts
39
+ batch_norm (bool): Use BatchNorm after layers with an
40
+ activation function
41
+ up_mode (str): one of 'upconv' or 'upsample'.
42
+ 'upconv' will use transposed convolutions for
43
+ learned upsampling.
44
+ 'upsample' will use bilinear upsampling.
45
+ """
46
+ def __init__(self, in_channels=1, n_classes=1, depth=3, wf=6, padding=True,
47
+ batch_norm=False, up_mode='upconv', dropout=False):
48
+ super(UNet, self).__init__()
49
+ assert up_mode in ('upconv', 'upsample')
50
+ self.padding = padding
51
+ self.depth = depth
52
+ self.dropout = nn.Dropout3d() if dropout else nn.Sequential()
53
+ prev_channels = in_channels
54
+ self.down_path = nn.ModuleList()
55
+ for i in range(depth):
56
+ self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
57
+ padding, batch_norm))
58
+ prev_channels = 2**(wf+i)
59
+
60
+ self.up_path = nn.ModuleList()
61
+ for i in reversed(range(depth - 1)):
62
+ self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
63
+ padding, batch_norm))
64
+ prev_channels = 2**(wf+i)
65
+
66
+ self.last = nn.Conv3d(prev_channels, n_classes, kernel_size=1)
67
+
68
+ def forward(self, x):
69
+ blocks = []
70
+ for i, down in enumerate(self.down_path):
71
+ x = down(x)
72
+ if i != len(self.down_path)-1:
73
+ blocks.append(x)
74
+ x = F.avg_pool3d(x, 2)
75
+ x = self.dropout(x)
76
+
77
+ for i, up in enumerate(self.up_path):
78
+ x = up(x, blocks[-i-1])
79
+
80
+ return self.last(x)
81
+
82
+
83
+ class UNetConvBlock(nn.Module):
84
+ def __init__(self, in_size, out_size, padding, batch_norm):
85
+ super(UNetConvBlock, self).__init__()
86
+ block = []
87
+
88
+ block.append(nn.Conv3d(in_size, out_size, kernel_size=3,
89
+ padding=int(padding)))
90
+ block.append(nn.ReLU())
91
+ if batch_norm:
92
+ block.append(nn.BatchNorm3d(out_size))
93
+
94
+ block.append(nn.Conv3d(out_size, out_size, kernel_size=3,
95
+ padding=int(padding)))
96
+ block.append(nn.ReLU())
97
+ if batch_norm:
98
+ block.append(nn.BatchNorm3d(out_size))
99
+
100
+ self.block = nn.Sequential(*block)
101
+
102
+ def forward(self, x):
103
+ out = self.block(x)
104
+ return out
105
+
106
+
107
+ class UNetUpBlock(nn.Module):
108
+ def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
109
+ super(UNetUpBlock, self).__init__()
110
+ if up_mode == 'upconv':
111
+ self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=2,
112
+ stride=2)
113
+ elif up_mode == 'upsample':
114
+ self.up = nn.Sequential(nn.Upsample(mode='trilinear', scale_factor=2),
115
+ nn.Conv3d(in_size, out_size, kernel_size=1))
116
+
117
+ self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)
118
+
119
+ def center_crop(self, layer, target_size):
120
+ _, _, layer_depth, layer_height, layer_width = layer.size()
121
+ diff_z = (layer_depth - target_size[0]) // 2
122
+ diff_y = (layer_height - target_size[1]) // 2
123
+ diff_x = (layer_width - target_size[2]) // 2
124
+ return layer[:, :, diff_z:(diff_z + target_size[0]), diff_y:(diff_y + target_size[1]), diff_x:(diff_x + target_size[2])]
125
+ # _, _, layer_height, layer_width = layer.size() #for 2D data
126
+ # diff_y = (layer_height - target_size[0]) // 2
127
+ # diff_x = (layer_width - target_size[1]) // 2
128
+ # return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]
129
+
130
+ def forward(self, x, bridge):
131
+ up = self.up(x)
132
+ # bridge = self.center_crop(bridge, up.shape[2:]) #sending shape ignoring 2 digit, so target size start with 0,1,2
133
+ up = F.interpolate(up, size=bridge.shape[2:], mode='trilinear')
134
+ out = torch.cat([up, bridge], 1)
135
+ out = self.conv_block(out)
136
+
137
+ return out