""" | |
C. HAN, C. WU, H. GUO, M. HU, AND H. CHEN, | |
“HANET: A HIERARCHICAL ATTENTION NETWORK FOR CHANGE DETECTION WITH BI-TEMPORAL VERY-HIGH-RESOLUTION REMOTE SENSING IMAGES,” | |
IEEE J. SEL. TOP. APPL. EARTH OBS. REMOTE SENS., PP. 1-17, 2023, DOI: 10.1109/JSTARS.2023.3264802. | |
Some code in this file is borrowed from: | |
https://github.com/ChengxiHAN/HANet-CD/blob/main/models/HANet.py | |
""" | |
import torch
import torch.nn as nn

from opencd.registry import MODELS

class CAM_Module(nn.Module):
    """Channel attention module."""

    def __init__(self, in_dim):
        super(CAM_Module, self).__init__()
        self.channel_in = in_dim
        # Learnable residual weight, initialized to 0 so the module starts as identity.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        m_batchsize, C, height, width = x.size()
        proj_query = x.view(m_batchsize, C, -1)                 # (b, C, h*w)
        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)  # (b, h*w, C)
        energy = torch.bmm(proj_query, proj_key)                # (b, C, C)
        # Subtract each row's max before the softmax (numerical stabilization).
        energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
        attention = self.softmax(energy_new)
        proj_value = x.view(m_batchsize, C, -1)                 # (b, C, h*w)
        out = torch.bmm(attention, proj_value)
        out = out.view(m_batchsize, C, height, width)
        out = self.gamma * out + x
        return out
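
# Note (added): the attention map computed above is C x C rather than
# (h*w) x (h*w), so the cost of channel attention scales with the channel
# count, not with the spatial resolution of the input.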

class Conv_CAM_Layer(nn.Module):
    """3x3 conv -> channel attention -> 3x3 conv."""

    def __init__(self, in_ch, out_ch, use_pam=False):
        # `use_pam` is accepted for interface compatibility but is unused here.
        super(Conv_CAM_Layer, self).__init__()
        self.attn = nn.Sequential(
            nn.Conv2d(in_ch, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.PReLU(),
            CAM_Module(32),
            nn.Conv2d(32, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.PReLU()
        )

    def forward(self, x):
        return self.attn(x)
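
# Note (added): Conv_CAM_Layer squeezes any input to a fixed 32-channel
# bottleneck, applies channel attention there, and expands back out, so the
# wrapped CAM_Module always attends over a 32 x 32 channel map regardless of
# in_ch and out_ch.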

class FEC(nn.Module):
    """Feature extraction cell: a 3x3 conv followed by a 1x1 refinement
    with a residual connection."""

    def __init__(self, in_ch, mid_ch, out_ch):
        super(FEC, self).__init__()
        self.activation = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1, bias=True)
        self.bn1 = nn.BatchNorm2d(mid_ch)
        self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

    def forward(self, x):
        x = self.conv1(x)
        identity = x  # skip connection taken before the first BatchNorm
        x = self.bn1(x)
        x = self.activation(x)
        x = self.conv2(x)
        x = self.bn2(x)
        output = self.activation(x + identity)
        return output
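
# Note (added): the skip-add in FEC.forward only type-checks when
# mid_ch == out_ch, which holds for every FEC instantiated by HAN below,
# e.g. FEC(40, 80, 80) when base_channel = 40.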

class RowAttention(nn.Module):
    """Self-attention restricted to each image row."""

    def __init__(self, in_dim, q_k_dim, use_pam=False):
        '''
        Parameters
        ----------
        in_dim : int
            channels of the input tensor
        q_k_dim : int
            channels of the Q and K projections
        '''
        # `use_pam` is accepted for interface compatibility but is unused here.
        super(RowAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dim, height, width) -- (b, c1, h, w)
        '''
        b, _, h, w = x.size()
        Q = self.query_conv(x)  # (b, c2, h, w)
        K = self.key_conv(x)    # (b, c2, h, w)
        V = self.value_conv(x)  # (b, c1, h, w)
        # Fold the height axis into the batch so each row attends independently.
        Q = Q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1)  # (b*h, w, c2)
        K = K.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)                   # (b*h, c2, w)
        V = V.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w)                   # (b*h, c1, w)
        row_attn = torch.bmm(Q, K)         # (b*h, w, w)
        row_attn = self.softmax(row_attn)  # normalize over the key positions
        out = torch.bmm(V, row_attn.permute(0, 2, 1))    # (b*h, c1, w)
        out = out.view(b, h, -1, w).permute(0, 2, 1, 3)  # (b, c1, h, w)
        out = self.gamma * out + x
        return out
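
# Note (added): RowAttention is one half of an axial attention. Each pixel
# attends only to the w pixels in its own row, giving b*h attention maps of
# size w x w instead of one (h*w) x (h*w) map as in full self-attention.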

class ColAttention(nn.Module):
    """Self-attention restricted to each image column."""

    def __init__(self, in_dim, q_k_dim, use_pam=False):
        '''
        Parameters
        ----------
        in_dim : int
            channels of the input tensor
        q_k_dim : int
            channels of the Q and K projections
        '''
        # `use_pam` is accepted for interface compatibility but is unused here.
        super(ColAttention, self).__init__()
        self.in_dim = in_dim
        self.q_k_dim = q_k_dim
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        '''
        Parameters
        ----------
        x : Tensor
            4-D, (batch, in_dim, height, width) -- (b, c1, h, w)
        '''
        b, _, h, w = x.size()
        Q = self.query_conv(x)  # (b, c2, h, w)
        K = self.key_conv(x)    # (b, c2, h, w)
        V = self.value_conv(x)  # (b, c1, h, w)
        # Fold the width axis into the batch so each column attends independently.
        Q = Q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1)  # (b*w, h, c2)
        K = K.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)                   # (b*w, c2, h)
        V = V.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h)                   # (b*w, c1, h)
        col_attn = torch.bmm(Q, K)         # (b*w, h, h)
        col_attn = self.softmax(col_attn)  # normalize over the key positions
        out = torch.bmm(V, col_attn.permute(0, 2, 1))    # (b*w, c1, h)
        out = out.view(b, w, -1, h).permute(0, 2, 3, 1)  # (b, c1, h, w)
        out = self.gamma * out + x
        return out
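
# Note (added): ColAttention mirrors RowAttention along the height axis.
# HAN.forward composes the two (row attention after column attention), which
# lets information reach the whole h x w plane at a fraction of the cost of
# dense self-attention.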

@MODELS.register_module()
class HAN(nn.Module):
    """HANet: hierarchical attention network for bi-temporal change detection."""
    def __init__(self, in_channels, base_channel=40):
        super(HAN, self).__init__()
        torch.nn.Module.dump_patches = True  # legacy serialization flag kept from the original code
        n1 = base_channel  # initial number of feature-map channels
        filters = [n1, n1 * 2, n1 * 4, n1 * 8]

        # Shared siamese encoder: a 5x5 stem followed by four FEC stages.
        self.conv0_0 = nn.Conv2d(in_channels, n1, kernel_size=5, padding=2, stride=1)
        self.conv0 = FEC(filters[0], filters[0], filters[0])
        self.conv2 = FEC(filters[0], filters[1], filters[1])
        self.conv4 = FEC(filters[1], filters[2], filters[2])
        self.conv5 = FEC(filters[2], filters[3], filters[3])
        # Final fusion of all decoded scales: sum(filters) -> filters[1] channels.
        self.conv6 = nn.Conv2d(sum(filters), filters[1], kernel_size=1, stride=1)

        # Multi-dilation grouped convs (rates 1-4) that fuse the concatenated
        # bi-temporal features at each scale, each followed by a 1x1 mixing conv.
        self.conv6_1_1 = nn.Conv2d(filters[0] * 2, filters[0], padding=1, kernel_size=3, groups=filters[0] // 2, dilation=1)
        self.conv6_1_2 = nn.Conv2d(filters[0] * 2, filters[0], padding=2, kernel_size=3, groups=filters[0] // 2, dilation=2)
        self.conv6_1_3 = nn.Conv2d(filters[0] * 2, filters[0], padding=3, kernel_size=3, groups=filters[0] // 2, dilation=3)
        self.conv6_1_4 = nn.Conv2d(filters[0] * 2, filters[0], padding=4, kernel_size=3, groups=filters[0] // 2, dilation=4)
        self.conv1_1 = nn.Conv2d(filters[0] * 4, filters[0], kernel_size=1, stride=1)
        self.conv6_2_1 = nn.Conv2d(filters[1] * 2, filters[1], padding=1, kernel_size=3, groups=filters[1] // 2, dilation=1)
        self.conv6_2_2 = nn.Conv2d(filters[1] * 2, filters[1], padding=2, kernel_size=3, groups=filters[1] // 2, dilation=2)
        self.conv6_2_3 = nn.Conv2d(filters[1] * 2, filters[1], padding=3, kernel_size=3, groups=filters[1] // 2, dilation=3)
        self.conv6_2_4 = nn.Conv2d(filters[1] * 2, filters[1], padding=4, kernel_size=3, groups=filters[1] // 2, dilation=4)
        self.conv2_1 = nn.Conv2d(filters[1] * 4, filters[1], kernel_size=1, stride=1)
        self.conv6_3_1 = nn.Conv2d(filters[2] * 2, filters[2], padding=1, kernel_size=3, groups=filters[2] // 2, dilation=1)
        self.conv6_3_2 = nn.Conv2d(filters[2] * 2, filters[2], padding=2, kernel_size=3, groups=filters[2] // 2, dilation=2)
        self.conv6_3_3 = nn.Conv2d(filters[2] * 2, filters[2], padding=3, kernel_size=3, groups=filters[2] // 2, dilation=3)
        self.conv6_3_4 = nn.Conv2d(filters[2] * 2, filters[2], padding=4, kernel_size=3, groups=filters[2] // 2, dilation=4)
        self.conv3_1 = nn.Conv2d(filters[2] * 4, filters[2], kernel_size=1, stride=1)
        self.conv6_4_1 = nn.Conv2d(filters[3] * 2, filters[3], padding=1, kernel_size=3, groups=filters[3] // 2, dilation=1)
        self.conv6_4_2 = nn.Conv2d(filters[3] * 2, filters[3], padding=2, kernel_size=3, groups=filters[3] // 2, dilation=2)
        self.conv6_4_3 = nn.Conv2d(filters[3] * 2, filters[3], padding=3, kernel_size=3, groups=filters[3] // 2, dilation=3)
        self.conv6_4_4 = nn.Conv2d(filters[3] * 2, filters[3], padding=4, kernel_size=3, groups=filters[3] // 2, dilation=4)
        self.conv4_1 = nn.Conv2d(filters[3] * 4, filters[3], kernel_size=1, stride=1)

        # Channel attention, one per scale
        self.cam_attention_1 = Conv_CAM_Layer(filters[0], filters[0], False)  # SA4
        self.cam_attention_2 = Conv_CAM_Layer(filters[1], filters[1], False)  # SA3
        self.cam_attention_3 = Conv_CAM_Layer(filters[2], filters[2], False)  # SA2
        self.cam_attention_4 = Conv_CAM_Layer(filters[3], filters[3], False)  # SA1
        # Row attention
        self.row_attention_1 = RowAttention(filters[0], filters[0], False)  # SA4
        self.row_attention_2 = RowAttention(filters[1], filters[1], False)  # SA3
        self.row_attention_3 = RowAttention(filters[2], filters[2], False)  # SA2
        self.row_attention_4 = RowAttention(filters[3], filters[3], False)  # SA1
        # Column attention
        self.col_attention_1 = ColAttention(filters[0], filters[0], False)  # SA4
        self.col_attention_2 = ColAttention(filters[1], filters[1], False)  # SA3
        self.col_attention_3 = ColAttention(filters[2], filters[2], False)  # SA2
        self.col_attention_4 = ColAttention(filters[3], filters[3], False)  # SA1

        self.c4_conv = nn.Conv2d(filters[3], filters[1], kernel_size=3, padding=1)
        self.c3_conv = nn.Conv2d(filters[2], filters[1], kernel_size=3, padding=1)
        self.c2_conv = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1)
        self.c1_conv = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1)

        self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
        self.Up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.Up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
        self.Up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
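
    # Note (added): self.c1_conv through self.c4_conv, self.Up2, and self.Up3
    # are built in __init__ but never used in forward; they are presumably
    # kept for compatibility with the original HANet-CD weights.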
    def forward(self, x1, x2):
        # Siamese encoder applied to both temporal images with shared weights.
        x1 = self.conv0(self.conv0_0(x1))  # output of the first scale
        x3 = self.conv2(self.pool(x1))
        x4 = self.conv4(self.pool(x3))
        A_F4 = self.conv5(self.pool(x4))
        x2 = self.conv0(self.conv0_0(x2))
        x5 = self.conv2(self.pool(x2))
        x6 = self.conv4(self.pool(x5))
        A_F8 = self.conv5(self.pool(x6))

        # At each scale: fuse the concatenated bi-temporal features with four
        # dilation rates, mix with a 1x1 conv, then apply channel attention in
        # parallel with column + row attention; decode coarse-to-fine.
        c4_1 = self.conv4_1(
            torch.cat([self.conv6_4_1(torch.cat([A_F4, A_F8], 1)), self.conv6_4_2(torch.cat([A_F4, A_F8], 1)),
                       self.conv6_4_3(torch.cat([A_F4, A_F8], 1)), self.conv6_4_4(torch.cat([A_F4, A_F8], 1))], 1))
        c4 = self.cam_attention_4(c4_1) + self.row_attention_4(self.col_attention_4(c4_1))
        c3_1 = self.conv3_1(torch.cat(
            [self.conv6_3_1(torch.cat([x4, x6], 1)), self.conv6_3_2(torch.cat([x4, x6], 1)),
             self.conv6_3_3(torch.cat([x4, x6], 1)), self.conv6_3_4(torch.cat([x4, x6], 1))], 1))
        c3 = torch.cat([(self.cam_attention_3(c3_1) + self.row_attention_3(self.col_attention_3(c3_1))), self.Up1(c4)], 1)
        c2_1 = self.conv2_1(torch.cat(
            [self.conv6_2_1(torch.cat([x3, x5], 1)), self.conv6_2_2(torch.cat([x3, x5], 1)),
             self.conv6_2_3(torch.cat([x3, x5], 1)), self.conv6_2_4(torch.cat([x3, x5], 1))], 1))
        c2 = torch.cat([(self.cam_attention_2(c2_1) + self.row_attention_2(self.col_attention_2(c2_1))), self.Up1(c3)], 1)
        c1_1 = self.conv1_1(torch.cat(
            [self.conv6_1_1(torch.cat([x1, x2], 1)), self.conv6_1_2(torch.cat([x1, x2], 1)),
             self.conv6_1_3(torch.cat([x1, x2], 1)), self.conv6_1_4(torch.cat([x1, x2], 1))], 1))
        c1 = torch.cat([(self.cam_attention_1(c1_1) + self.row_attention_1(self.col_attention_1(c1_1))), self.Up1(c2)], 1)
        # c1 now stacks all four decoded scales: sum(filters) channels in total.
        out1 = self.conv6(c1)
        return (out1, )
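

if __name__ == '__main__':
    # Minimal smoke test (added; not part of the borrowed file). Shapes are
    # assumed: input sides must be divisible by 8 so the three 2x average
    # poolings divide evenly. Requires the opencd package to be importable.
    model = HAN(in_channels=3, base_channel=40).eval()
    t1 = torch.randn(2, 3, 64, 64)  # image at time 1
    t2 = torch.randn(2, 3, 64, 64)  # image at time 2
    with torch.no_grad():
        (out,) = model(t1, t2)
    # conv6 fuses sum(filters) = 15 * base_channel channels into 2 * base_channel.
    print(out.shape)  # torch.Size([2, 80, 64, 64])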