Upload UNetFTSR
- UNetFTSR.py +19 -0
- UNetFTSRConfig.py +25 -0
- config.json +20 -0
- model.safetensors +3 -0
- unet3D.py +137 -0
UNetFTSR.py
ADDED
@@ -0,0 +1,19 @@
from transformers import PreTrainedModel
from .unet3D import UNet
from .UNetFTSRConfig import UNetFTSRConfig

class UNetFTSR(PreTrainedModel):
    config_class = UNetFTSRConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = UNet(
            in_channels=config.in_channels,
            n_classes=config.n_classes,
            depth=config.depth,
            wf=config.wf,
            padding=config.padding,
            batch_norm=config.batch_norm,
            up_mode=config.up_mode,
            dropout=config.dropout)

    def forward(self, x):
        return self.model(x)
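For orientation, a minimal usage sketch of this wrapper (not part of the upload): instantiate the config, wrap it in the model, and push a dummy volume through. It assumes the repository files are importable as a package (the relative imports above require that), and the 16^3 input size is an illustrative choice.

# Hypothetical sketch, run from inside a package containing the files above.
import torch
from .UNetFTSRConfig import UNetFTSRConfig
from .UNetFTSR import UNetFTSR

config = UNetFTSRConfig()           # defaults: in_channels=1, n_classes=1, depth=3
model = UNetFTSR(config).eval()
x = torch.randn(1, 1, 16, 16, 16)   # (batch, channel, depth, height, width)
with torch.no_grad():
    y = model(x)                    # same spatial shape as x, since padding=True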
UNetFTSRConfig.py
ADDED
@@ -0,0 +1,25 @@
from transformers import PretrainedConfig
from typing import List

class UNetFTSRConfig(PretrainedConfig):
    model_type = "UNet"

    def __init__(
            self,
            in_channels=1,
            n_classes=1,
            depth=3,
            wf=6,
            padding=True,
            batch_norm=False,
            up_mode='upconv',
            dropout=False,
            **kwargs):
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.depth = depth
        self.wf = wf
        self.padding = padding
        self.batch_norm = batch_norm
        self.up_mode = up_mode
        self.dropout = dropout
        super().__init__(**kwargs)
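A quick sketch of the config round-trip (hypothetical paths; `save_pretrained` on a `PretrainedConfig` subclass writes a config.json like the one in the next file):

# Hypothetical sketch: serializing and reloading the config.
from UNetFTSRConfig import UNetFTSRConfig

cfg = UNetFTSRConfig(depth=3, wf=6)
cfg.save_pretrained("./UNetFTSR")                    # writes ./UNetFTSR/config.json
cfg2 = UNetFTSRConfig.from_pretrained("./UNetFTSR")
assert cfg2.depth == 3 and cfg2.wf == 6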
config.json
ADDED
@@ -0,0 +1,20 @@
{
  "architectures": [
    "UNetFTSR"
  ],
  "auto_map": {
    "AutoConfig": "UNetFTSRConfig.UNetFTSRConfig",
    "AutoModel": "UNetFTSR.UNetFTSR"
  },
  "batch_norm": false,
  "depth": 3,
  "dropout": false,
  "in_channels": 1,
  "model_type": "UNet",
  "n_classes": 1,
  "padding": true,
  "torch_dtype": "float32",
  "transformers_version": "4.44.2",
  "up_mode": "upconv",
  "wf": 6
}
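The auto_map block is what lets the generic Auto classes find the custom code in this repository. A loading sketch ("user/UNetFTSR" is a placeholder repo id):

# Hypothetical sketch: loading through the Auto classes, which resolve the
# custom classes via the auto_map above. trust_remote_code=True is required
# because the modeling code lives in the repository itself, not in transformers.
from transformers import AutoModel

model = AutoModel.from_pretrained("user/UNetFTSR", trust_remote_code=True)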
model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c6135269f442ff8de0dc58f670de1e65ac5fb98de769d20a03a6e64fcf7bbd6e
size 21675396
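As a sanity check on the size: at float32 (4 bytes per parameter, matching "torch_dtype" in config.json), 21,675,396 bytes correspond to roughly 21,675,396 / 4 ≈ 5.42 million parameters, which is about what a depth-3 3D U-Net starting at 2^6 = 64 filters comes to.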
unet3D.py
ADDED
@@ -0,0 +1,137 @@
# This model is part of the paper "Fine-tuning deep learning model parameters for improved super-resolution of dynamic MRI with prior-knowledge" (https://doi.org/10.1016/j.artmed.2021.102196)
# and has been published on GitHub: https://github.com/soumickmj/FTSuperResDynMRI/blob/main/models/unet3D.py


import torch
from torch import nn
import torch.nn.functional as F

__author__ = "Soumick Chatterjee, Chompunuch Sarasaen"
__copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
__credits__ = ["Soumick Chatterjee", "Chompunuch Sarasaen"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Soumick Chatterjee"
__email__ = "[email protected]"
__status__ = "Published"


class UNet(nn.Module):
    """
    3D adaptation of
    U-Net: Convolutional Networks for Biomedical Image Segmentation
    (Ronneberger et al., 2015)
    https://arxiv.org/abs/1505.04597

    Note that the defaults here (3D convolutions, depth=3, padding=True)
    differ from the exact 2D network used in the original paper.

    Adapted from https://discuss.pytorch.org/t/unet-implementation/426

    Args:
        in_channels (int): number of input channels
        n_classes (int): number of output channels
        depth (int): depth of the network
        wf (int): number of filters in the first layer is 2**wf
        padding (bool): if True, apply padding such that the input shape
                        is the same as the output.
                        This may introduce artifacts
        batch_norm (bool): use BatchNorm after layers with an
                           activation function
        up_mode (str): one of 'upconv' or 'upsample'.
                       'upconv' will use transposed convolutions for
                       learned upsampling.
                       'upsample' will use trilinear upsampling.
        dropout (bool): if True, apply Dropout3d after each downsampling step
    """
    def __init__(self, in_channels=1, n_classes=1, depth=3, wf=6, padding=True,
                 batch_norm=False, up_mode='upconv', dropout=False):
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        self.dropout = nn.Dropout3d() if dropout else nn.Sequential()
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(UNetConvBlock(prev_channels, 2**(wf+i),
                                                padding, batch_norm))
            prev_channels = 2**(wf+i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(UNetUpBlock(prev_channels, 2**(wf+i), up_mode,
                                            padding, batch_norm))
            prev_channels = 2**(wf+i)

        self.last = nn.Conv3d(prev_channels, n_classes, kernel_size=1)

    def forward(self, x):
        blocks = []
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path)-1:
                blocks.append(x)
                x = F.avg_pool3d(x, 2)
                x = self.dropout(x)

        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i-1])

        return self.last(x)


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv3d(in_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm3d(out_size))

        block.append(nn.Conv3d(out_size, out_size, kernel_size=3,
                               padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm3d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=2,
                                         stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(nn.Upsample(mode='trilinear', scale_factor=2),
                                    nn.Conv3d(in_size, out_size, kernel_size=1))

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_depth, layer_height, layer_width = layer.size()
        diff_z = (layer_depth - target_size[0]) // 2
        diff_y = (layer_height - target_size[1]) // 2
        diff_x = (layer_width - target_size[2]) // 2
        return layer[:, :, diff_z:(diff_z + target_size[0]), diff_y:(diff_y + target_size[1]), diff_x:(diff_x + target_size[2])]
        # _, _, layer_height, layer_width = layer.size()  # for 2D data
        # diff_y = (layer_height - target_size[0]) // 2
        # diff_x = (layer_width - target_size[1]) // 2
        # return layer[:, :, diff_y:(diff_y + target_size[0]), diff_x:(diff_x + target_size[1])]

    def forward(self, x, bridge):
        up = self.up(x)
        # bridge = self.center_crop(bridge, up.shape[2:])  # target_size excludes the batch and channel dims, so its indices start at 0
        up = F.interpolate(up, size=bridge.shape[2:], mode='trilinear')
        out = torch.cat([up, bridge], 1)
        out = self.conv_block(out)

        return out
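To make the defaults concrete: with depth=3 and wf=6 the encoder channel widths run 1 → 64 → 128 → 256, the decoder runs back 256 → 128 → 64, and padding=True preserves the spatial extent. A small shape-check sketch (the 32^3 input is an illustrative choice; extents should be divisible by 2**(depth-1) = 4 for the two pooling steps):

# Hypothetical sketch: checking that the output shape matches the input.
import torch
from unet3D import UNet

net = UNet(in_channels=1, n_classes=1, depth=3, wf=6, padding=True).eval()
x = torch.randn(1, 1, 32, 32, 32)   # (batch, channel, D, H, W)
with torch.no_grad():
    y = net(x)
print(y.shape)                      # torch.Size([1, 1, 32, 32, 32])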