File size: 5,836 Bytes
abd2a81 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from collections import OrderedDict
from torch import nn
from torch.nn import functional as F
import torch
import torchvision
def conv3x3(in_, out):
    """Build a 3x3 2-D convolution mapping ``in_`` channels to ``out`` channels.

    The layer uses a padding of 1 and leaves every other ``nn.Conv2d`` option
    at its default.
    """
    return nn.Conv2d(in_channels=in_, out_channels=out, kernel_size=3, padding=1)
class ConvRelu(nn.Module):
    """A 3x3 convolution (built by ``conv3x3``) followed by an in-place ReLU.

    Args:
        in_: Number of input channels of the convolution.
        out: Number of output channels of the convolution.
    """

    def __init__(self, in_, out):
        super(ConvRelu, self).__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv = conv3x3(in_, out)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution and then the ReLU to ``x``."""
        return self.activation(self.conv(x))
class DecoderBlockV2(nn.Module):
    """Decoder stage whose layer stack upsamples its input by a factor of 2.

    Args:
        in_channels: Channels of the incoming feature map.
        middle_channels: Channels produced by the first convolution.
        out_channels: Channels of the block output.
        is_deconv: When True, the stack is ``ConvRelu`` -> stride-2
            ``ConvTranspose2d`` -> ReLU. When False, it is a bilinear 2x
            ``Upsample`` followed by two (3x3 conv -> BatchNorm -> ELU) groups.
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.in_channels = in_channels

        if not is_deconv:
            layers = [
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
                nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(middle_channels),
                nn.ELU(),
                nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1, bias=True),
                nn.BatchNorm2d(out_channels),
                nn.ELU(),
            ]
        else:
            # Parameters for the deconvolution were chosen to avoid artifacts,
            # following https://distill.pub/2016/deconv-checkerboard/
            layers = [
                ConvRelu(in_channels, middle_channels),
                nn.ConvTranspose2d(
                    middle_channels,
                    out_channels,
                    kernel_size=4,
                    stride=2,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
            ]

        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the configured layer stack."""
        return self.block(x)
def cat_non_matching(x1, x2):
    """Concatenate ``x1`` and ``x2`` along dim 1 after padding ``x2`` to ``x1``.

    The size differences in dim 2 and dim 3 are split between the two sides
    of ``x2``; when a difference is odd, the trailing side receives the extra
    element.

    For background on the padding arithmetic, see
    https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
    https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
    """
    height_gap = x1.size(2) - x2.size(2)
    width_gap = x1.size(3) - x2.size(3)

    left, top = width_gap // 2, height_gap // 2
    right, bottom = width_gap - left, height_gap - top

    # F.pad takes the amounts for the last dim first: (left, right, top, bottom).
    x2_fitted = F.pad(x2, (left, right, top, bottom))
    return torch.cat((x1, x2_fitted), dim=1)
class UNetResNetBackbone(nn.Module):
    """PyTorch U-Net model using ResNet(34, 101 or 152) encoder.

    UNet: https://arxiv.org/abs/1505.04597
    ResNet: https://arxiv.org/abs/1512.03385
    Proposed by Alexander Buslaev: https://www.linkedin.com/in/al-buslaev/

    Args:
        encoder_depth (int): Depth of a ResNet encoder (34, 101 or 152).
        num_filters (int, optional): Number of filters in the last layer of decoder. Defaults to 32.
        dropout_2d (float, optional): Probability factor of dropout layer before output layer. Defaults to 0.2.
        pretrained (bool, optional):
            False - no pre-trained weights are being used.
            True  - ResNet encoder is pre-trained on ImageNet.
            Defaults to False.
        is_deconv (bool, optional):
            False: bilinear interpolation is used in decoder.
            True: deconvolution is used in decoder.
            Defaults to False.

    Raises:
        NotImplementedError: If ``encoder_depth`` is not 34, 101 or 152.
    """

    def __init__(self, encoder_depth, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.dropout_2d = dropout_2d

        # NOTE(review): newer torchvision releases prefer a ``weights=``
        # argument over ``pretrained=``; kept unchanged for compatibility --
        # confirm against the torchvision version this project pins.
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        else:
            raise NotImplementedError('only 34, 101, 152 version of ResNet are implemented')

        # The same pooling module is used inside ``conv1`` and again before
        # ``center`` in ``forward``.
        self.pool = nn.MaxPool2d(2, 2)

        # Not used by ``forward``; kept so the attribute stays available to
        # any external code that reads it.
        self.relu = nn.ReLU(inplace=True)

        # Encoder stages, reusing the ResNet's own layers.
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)
        self.conv2 = self.encoder.layer1
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4

        # Decoder stages. Each ``dec*`` input width is the matching encoder
        # stage's channel count plus the previous decoder stage's output.
        self.center = DecoderBlockV2(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8,
                                   is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2,
                                   is_deconv)
        self.dec2 = DecoderBlockV2(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                   is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)

    def forward(self, x):
        """Run the encoder/decoder and return the decoded feature map.

        Returns:
            OrderedDict: A mapping with the single key ``"out"`` holding the
            output of ``dec1`` after 2-D dropout.
        """
        # Encoder path.
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)

        pool = self.pool(conv5)
        center = self.center(pool)

        # Decoder path with skip connections from the encoder stages.
        dec5 = self.dec5(cat_non_matching(conv5, center))
        dec4 = self.dec4(cat_non_matching(conv4, dec5))
        dec3 = self.dec3(cat_non_matching(conv3, dec4))
        dec2 = self.dec2(cat_non_matching(conv2, dec3))
        dec1 = self.dec1(dec2)

        # The functional dropout defaults to training=True, so the module's
        # mode has to be forwarded explicitly; otherwise channels would still
        # be dropped after model.eval().
        y = F.dropout2d(dec1, p=self.dropout_2d, training=self.training)

        result = OrderedDict()
        result["out"] = y
        return result
|