"""
S. Fang, K. Li, J. Shao, and Z. Li,
“SNUNet-CD: A Densely Connected Siamese Network for Change Detection of VHR Images,”
IEEE Geosci. Remote Sensing Lett., pp. 1-5, 2021, doi: 10.1109/LGRS.2021.3056416.
"""
import torch
import torch.nn as nn
from opencd.registry import MODELS
class conv_block_nested(nn.Module):
def __init__(self, in_ch, mid_ch, out_ch):
super(conv_block_nested, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
self.bn1 = nn.BatchNorm2d(mid_ch)
self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, bias=True)
self.bn2 = nn.BatchNorm2d(out_ch)
def forward(self, x):
x = self.conv1(x)
identity = x
x = self.bn1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn2(x)
output = self.activation(x + identity)
return output
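
# Illustrative note (not part of the original file): conv_block_nested is a
# shape-preserving residual double-conv block; the skip add requires
# mid_ch == out_ch, which holds everywhere it is used below, e.g.
#   block = conv_block_nested(3, 32, 32)
#   block(torch.randn(1, 3, 64, 64)).shape  # -> torch.Size([1, 32, 64, 64])
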
class up(nn.Module):
def __init__(self, in_ch, bilinear=False):
super(up, self).__init__()
if bilinear:
self.up = nn.Upsample(scale_factor=2,
mode='bilinear',
align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_ch, in_ch, 2, stride=2)
def forward(self, x):
x = self.up(x)
return x
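
# Illustrative note (not part of the original file): `up` doubles the spatial
# resolution, via a learnable 2x2 transposed conv by default or bilinear
# interpolation when bilinear=True, e.g.
#   up(64)(torch.randn(1, 64, 16, 16)).shape  # -> torch.Size([1, 64, 32, 32])
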
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)
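
# Illustrative note (not part of the original file): ChannelAttention is a
# CBAM-style gate that returns one weight in (0, 1) per channel, combining
# average- and max-pooled descriptors through a shared bottleneck, e.g.
#   ChannelAttention(128)(torch.randn(2, 128, 32, 32)).shape  # -> (2, 128, 1, 1)
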
@MODELS.register_module()
class SNUNet_ECAM(nn.Module):
# SNUNet-CD with ECAM
def __init__(self, in_channels, base_channel=32):
super(SNUNet_ECAM, self).__init__()
torch.nn.Module.dump_patches = True
n1 = base_channel # the initial number of channels of feature map
filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
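        # Shared Siamese encoder: the same conv/pool column below is applied to
        # both temporal images in forward(), so their features share weights.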
self.conv0_0 = conv_block_nested(in_channels, filters[0], filters[0])
self.conv1_0 = conv_block_nested(filters[0], filters[1], filters[1])
self.Up1_0 = up(filters[1])
self.conv2_0 = conv_block_nested(filters[1], filters[2], filters[2])
self.Up2_0 = up(filters[2])
self.conv3_0 = conv_block_nested(filters[2], filters[3], filters[3])
self.Up3_0 = up(filters[3])
self.conv4_0 = conv_block_nested(filters[3], filters[4], filters[4])
self.Up4_0 = up(filters[4])
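        # Dense (UNet++-style) decoder nodes: conv{i}_{j} concatenates (j + 1)
        # same-level tensors of filters[i] channels (both temporal encoder
        # features plus earlier nodes) with the upsampled deeper node of
        # filters[i + 1] channels.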
self.conv0_1 = conv_block_nested(filters[0] * 2 + filters[1], filters[0], filters[0])
self.conv1_1 = conv_block_nested(filters[1] * 2 + filters[2], filters[1], filters[1])
self.Up1_1 = up(filters[1])
self.conv2_1 = conv_block_nested(filters[2] * 2 + filters[3], filters[2], filters[2])
self.Up2_1 = up(filters[2])
self.conv3_1 = conv_block_nested(filters[3] * 2 + filters[4], filters[3], filters[3])
self.Up3_1 = up(filters[3])
self.conv0_2 = conv_block_nested(filters[0] * 3 + filters[1], filters[0], filters[0])
self.conv1_2 = conv_block_nested(filters[1] * 3 + filters[2], filters[1], filters[1])
self.Up1_2 = up(filters[1])
self.conv2_2 = conv_block_nested(filters[2] * 3 + filters[3], filters[2], filters[2])
self.Up2_2 = up(filters[2])
self.conv0_3 = conv_block_nested(filters[0] * 4 + filters[1], filters[0], filters[0])
self.conv1_3 = conv_block_nested(filters[1] * 4 + filters[2], filters[1], filters[1])
self.Up1_3 = up(filters[1])
self.conv0_4 = conv_block_nested(filters[0] * 5 + filters[1], filters[0], filters[0])
self.ca = ChannelAttention(filters[0] * 4, ratio=16)
self.ca1 = ChannelAttention(filters[0], ratio=16 // 4)
# self.conv_final = nn.Conv2d(filters[0] * 4, out_ch, kernel_size=1)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, xA, xB):
'''xA'''
x0_0A = self.conv0_0(xA)
x1_0A = self.conv1_0(self.pool(x0_0A))
x2_0A = self.conv2_0(self.pool(x1_0A))
x3_0A = self.conv3_0(self.pool(x2_0A))
# x4_0A = self.conv4_0(self.pool(x3_0A))
'''xB'''
x0_0B = self.conv0_0(xB)
x1_0B = self.conv1_0(self.pool(x0_0B))
x2_0B = self.conv2_0(self.pool(x1_0B))
x3_0B = self.conv3_0(self.pool(x2_0B))
x4_0B = self.conv4_0(self.pool(x3_0B))
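        # Dense skip pathways: each decoder node fuses the same-level features
        # from both dates with the upsampled output of the next-deeper node,
        # reusing all earlier nodes at that level.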
x0_1 = self.conv0_1(torch.cat([x0_0A, x0_0B, self.Up1_0(x1_0B)], 1))
x1_1 = self.conv1_1(torch.cat([x1_0A, x1_0B, self.Up2_0(x2_0B)], 1))
x0_2 = self.conv0_2(torch.cat([x0_0A, x0_0B, x0_1, self.Up1_1(x1_1)], 1))
x2_1 = self.conv2_1(torch.cat([x2_0A, x2_0B, self.Up3_0(x3_0B)], 1))
x1_2 = self.conv1_2(torch.cat([x1_0A, x1_0B, x1_1, self.Up2_1(x2_1)], 1))
x0_3 = self.conv0_3(torch.cat([x0_0A, x0_0B, x0_1, x0_2, self.Up1_2(x1_2)], 1))
x3_1 = self.conv3_1(torch.cat([x3_0A, x3_0B, self.Up4_0(x4_0B)], 1))
x2_2 = self.conv2_2(torch.cat([x2_0A, x2_0B, x2_1, self.Up3_1(x3_1)], 1))
x1_3 = self.conv1_3(torch.cat([x1_0A, x1_0B, x1_1, x1_2, self.Up2_2(x2_2)], 1))
x0_4 = self.conv0_4(torch.cat([x0_0A, x0_0B, x0_1, x0_2, x0_3, self.Up1_3(x1_3)], 1))
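        # Ensemble Channel Attention Module (ECAM): channel attention over the
        # element-wise sum of the four level-0 outputs is tiled across the four
        # groups and added to their concatenation, which is then reweighted by
        # a second channel attention over all 4 * filters[0] channels.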
out = torch.cat([x0_1, x0_2, x0_3, x0_4], 1)
intra = torch.sum(torch.stack((x0_1, x0_2, x0_3, x0_4)), dim=0)
ca1 = self.ca1(intra)
out = self.ca(out) * (out + ca1.repeat(1, 4, 1, 1))
# out = self.conv_final(out)
return (out, )
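

# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original open-cd module).
# Assumes 3-channel bi-temporal inputs at 256x256; checks that the ECAM output
# keeps the input resolution with 4 * base_channel channels.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    model = SNUNet_ECAM(in_channels=3, base_channel=32)
    model.eval()
    xA = torch.randn(1, 3, 256, 256)  # image at time t1
    xB = torch.randn(1, 3, 256, 256)  # image at time t2
    with torch.no_grad():
        out, = model(xA, xB)
    print(out.shape)  # expected: torch.Size([1, 128, 256, 256])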