KyanChen's picture
Upload 1861 files
3b96cb1
"""
C. HAN, C. WU, H. GUO, M. HU, AND H. CHEN,
“HANET: A HIERARCHICAL ATTENTION NETWORK FOR CHANGE DETECTION WITH BI-TEMPORAL VERY-HIGH-RESOLUTION REMOTE SENSING IMAGES,”
IEEE J. SEL. TOP. APPL. EARTH OBS. REMOTE SENS., PP. 1-17, 2023, DOI: 10.1109/JSTARS.2023.3264802.
Some code in this file is borrowed from:
https://github.com/ChengxiHAN/HANet-CD/blob/main/models/HANet.py
"""
import torch
import torch.nn as nn
from opencd.registry import MODELS
class CAM_Module(nn.Module):
""" Channel attention module"""
def __init__(self, in_dim):
super(CAM_Module, self).__init__()
self.chanel_in = in_dim
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
m_batchsize, C, height, width = x.size()
proj_query = x.view(m_batchsize, C, -1)
proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
energy = torch.bmm(proj_query, proj_key)
energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy) - energy
attention = self.softmax(energy_new)
proj_value = x.view(m_batchsize, C, -1)
out = torch.bmm(attention, proj_value)
out = out.view(m_batchsize, C, height, width)
out = self.gamma * out + x
return out
class Conv_CAM_Layer(nn.Module):
def __init__(self, in_ch, out_in, use_pam=False):
super(Conv_CAM_Layer, self).__init__()
self.attn = nn.Sequential(
nn.Conv2d(in_ch, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.PReLU(),
CAM_Module(32),
nn.Conv2d(32, out_in, kernel_size=3, padding=1),
nn.BatchNorm2d(out_in),
nn.PReLU()
)
def forward(self, x):
return self.attn(x)
class FEC(nn.Module):
"""feature extraction cell"""
#convolutional block
def __init__(self, in_ch, mid_ch, out_ch):
super(FEC, self).__init__()
self.activation = nn.ReLU(inplace=True)
self.conv1 = nn.Conv2d(in_ch, mid_ch, kernel_size=3, padding=1,bias=True)
self.bn1 = nn.BatchNorm2d(mid_ch)
self.conv2 = nn.Conv2d(mid_ch, out_ch, kernel_size=1, stride=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_ch)
def forward(self, x):
x = self.conv1(x)
identity = x
x = self.bn1(x)
x = self.activation(x)
x = self.conv2(x)
x = self.bn2(x)
output = self.activation(x + identity)
return output
class RowAttention(nn.Module):
def __init__(self, in_dim, q_k_dim, use_pam=False):
'''
Parameters
----------
in_dim : int
channel of input img tensor
q_k_dim: int
channel of Q, K vector
device : torch.device
'''
super(RowAttention, self).__init__()
self.in_dim = in_dim
self.q_k_dim = q_k_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
'''
Parameters
----------
x : Tensor
4-D , (batch, in_dims, height, width) -- (b,c1,h,w)
'''
b, _, h, w = x.size()
Q = self.query_conv(x) # size = (b,c2, h,w)
K = self.key_conv(x) # size = (b, c2, h, w)
V = self.value_conv(x) # size = (b, c1,h,w)
Q = Q.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w).permute(0, 2, 1) # size = (b*h,w,c2)
K = K.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w) # size = (b*h,c2,w)
V = V.permute(0, 2, 1, 3).contiguous().view(b * h, -1, w) # size = (b*h, c1,w)
row_attn = torch.bmm(Q, K)
row_attn = self.softmax(row_attn)
out = torch.bmm(V, row_attn.permute(0, 2, 1))
out = out.view(b, h, -1, w).permute(0, 2, 1, 3)
out = self.gamma * out + x
return out
class ColAttention(nn.Module):
def __init__(self, in_dim, q_k_dim, use_pam=False):
'''
Parameters
----------
in_dim : int
channel of input img tensor
q_k_dim: int
channel of Q, K vector
device : torch.device
'''
super(ColAttention, self).__init__()
self.in_dim = in_dim
self.q_k_dim = q_k_dim
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.q_k_dim, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=self.in_dim, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
'''
Parameters
----------
x : Tensor
4-D , (batch, in_dims, height, width) -- (b,c1,h,w)
'''
b, _, h, w = x.size()
Q = self.query_conv(x) # size = (b,c2, h,w)
K = self.key_conv(x) # size = (b, c2, h, w)
V = self.value_conv(x) # size = (b, c1,h,w)
Q = Q.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h).permute(0, 2, 1) # size = (b*w,h,c2)
K = K.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h) # size = (b*w,c2,h)
V = V.permute(0, 3, 1, 2).contiguous().view(b * w, -1, h) # size = (b*w,c1,h)
# size = (b*w,h,h) [:,i,j]
col_attn = torch.bmm(Q, K)
col_attn = self.softmax(col_attn)
out = torch.bmm(V, col_attn.permute(0, 2, 1))
# size = (b,c1,h,w)
out = out.view(b, w, -1, h).permute(0, 2, 3, 1)
out = self.gamma * out + x
return out
@MODELS.register_module()
class HAN(nn.Module):
"""HANet"""
def __init__(self, in_channels, base_channel=40):
super(HAN, self).__init__()
torch.nn.Module.dump_patches = True
n1 = base_channel # the initial number of channels of feature map
filters = [n1, n1 * 2, n1 * 4, n1 * 8]
self.conv0_0 = nn.Conv2d(in_channels, n1, kernel_size=5, padding=2, stride=1)
self.conv0 = FEC(filters[0], filters[0], filters[0])
self.conv2 = FEC(filters[0], filters[1], filters[1])
self.conv4 = FEC(filters[1], filters[2], filters[2])
self.conv5 = FEC(filters[2], filters[3], filters[3])
self.conv6 = nn.Conv2d(sum(filters), filters[1], kernel_size=1, stride=1)
self.conv6_1_1 = nn.Conv2d(filters[0] * 2, filters[0], padding=1, kernel_size=3, groups=filters[0] // 2,dilation=1)
self.conv6_1_2 = nn.Conv2d(filters[0] * 2, filters[0], padding=2, kernel_size=3, groups=filters[0] // 2,dilation=2)
self.conv6_1_3 = nn.Conv2d(filters[0] * 2, filters[0], padding=3, kernel_size=3, groups=filters[0] // 2,dilation=3)
self.conv6_1_4 = nn.Conv2d(filters[0] * 2, filters[0], padding=4, kernel_size=3, groups=filters[0] // 2,dilation=4)
self.conv1_1 = nn.Conv2d(filters[0] * 4, filters[0], kernel_size=1, stride=1)
self.conv6_2_1 = nn.Conv2d(filters[1] * 2, filters[1], padding=1, kernel_size=3, groups=filters[1] // 2, dilation=1)
self.conv6_2_2 = nn.Conv2d(filters[1] * 2, filters[1], padding=2, kernel_size=3, groups=filters[1] // 2, dilation=2)
self.conv6_2_3 = nn.Conv2d(filters[1] * 2, filters[1], padding=3, kernel_size=3, groups=filters[1] // 2, dilation=3)
self.conv6_2_4 = nn.Conv2d(filters[1] * 2, filters[1], padding=4, kernel_size=3, groups=filters[1] // 2, dilation=4)
self.conv2_1 = nn.Conv2d(filters[1] * 4, filters[1], kernel_size=1, stride=1)
self.conv6_3_1 = nn.Conv2d(filters[2] * 2, filters[2], padding=1, kernel_size=3, groups=filters[2] // 2, dilation=1)
self.conv6_3_2 = nn.Conv2d(filters[2] * 2, filters[2], padding=2, kernel_size=3, groups=filters[2] // 2, dilation=2)
self.conv6_3_3 = nn.Conv2d(filters[2] * 2, filters[2], padding=3, kernel_size=3, groups=filters[2] // 2, dilation=3)
self.conv6_3_4 = nn.Conv2d(filters[2] * 2, filters[2], padding=4, kernel_size=3, groups=filters[2] // 2, dilation=4)
self.conv3_1 = nn.Conv2d(filters[2] * 4, filters[2], kernel_size=1, stride=1)
self.conv6_4_1 = nn.Conv2d(filters[3]*2, filters[3], padding=1, kernel_size=3, groups=filters[3]//2, dilation=1)
self.conv6_4_2 = nn.Conv2d(filters[3]*2, filters[3], padding=2, kernel_size=3, groups=filters[3]//2, dilation=2)
self.conv6_4_3 = nn.Conv2d(filters[3]*2, filters[3], padding=3, kernel_size=3, groups=filters[3]//2, dilation=3)
self.conv6_4_4 = nn.Conv2d(filters[3]*2, filters[3], padding=4, kernel_size=3, groups=filters[3]//2, dilation=4)
self.conv4_1 = nn.Conv2d(filters[3]*4, filters[3], kernel_size=1, stride=1)
# SA
self.cam_attention_1 = Conv_CAM_Layer(filters[0], filters[0], False) #SA4
self.cam_attention_2 = Conv_CAM_Layer(filters[1], filters[1], False) #SA3
self.cam_attention_3 = Conv_CAM_Layer(filters[2], filters[2], False) #SA2
self.cam_attention_4 = Conv_CAM_Layer(filters[3], filters[3], False) #SA1
#Row Attention
self.row_attention_1 = RowAttention(filters[0], filters[0], False) # SA4
self.row_attention_2 = RowAttention(filters[1], filters[1], False) # SA3
self.row_attention_3 = RowAttention(filters[2], filters[2], False) # SA2
self.row_attention_4 = RowAttention(filters[3], filters[3], False) # SA1
# Col Attention
self.col_attention_1 = ColAttention(filters[0], filters[0], False) # SA4
self.col_attention_2 = ColAttention(filters[1], filters[1], False) # SA3
self.col_attention_3 = ColAttention(filters[2], filters[2], False) # SA2
self.col_attention_4 = ColAttention(filters[3], filters[3], False) # SA1
self.c4_conv = nn.Conv2d(filters[3], filters[1], kernel_size=3, padding=1)
self.c3_conv = nn.Conv2d(filters[2], filters[1], kernel_size=3, padding=1)
self.c2_conv = nn.Conv2d(filters[1], filters[1], kernel_size=3, padding=1)
self.c1_conv = nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1)
self.pool = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
self.Up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
self.Up2 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False)
self.Up3 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=False)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(self, x1, x2):
x1 = self.conv0(self.conv0_0(x1)) # Output of the first scale
x3 = self.conv2(self.pool(x1))
x4 = self.conv4(self.pool(x3))
A_F4 = self.conv5(self.pool(x4))
x2 = self.conv0(self.conv0_0(x2))
x5 = self.conv2(self.pool(x2))
x6 = self.conv4(self.pool(x5))
A_F8 = self.conv5(self.pool(x6))
c4_1 = self.conv4_1(
torch.cat([self.conv6_4_1(torch.cat([A_F4, A_F8], 1)), self.conv6_4_2(torch.cat([A_F4, A_F8], 1)),
self.conv6_4_3(torch.cat([A_F4, A_F8], 1)), self.conv6_4_4(torch.cat([A_F4, A_F8], 1))], 1))
c4 = self.cam_attention_4(c4_1) + self.row_attention_4(self.col_attention_4(c4_1))
c3_1 = (self.conv3_1(torch.cat(
[self.conv6_3_1(torch.cat([x4, x6], 1)), self.conv6_3_2(torch.cat([x4, x6], 1)),
self.conv6_3_3(torch.cat([x4, x6], 1)), self.conv6_3_4(torch.cat([x4, x6], 1))], 1)))
c3 = torch.cat([(self.cam_attention_3(c3_1)+self.row_attention_3(self.col_attention_3(c3_1))), self.Up1(c4)], 1)
c2_1 = (self.conv2_1(torch.cat(
[self.conv6_2_1(torch.cat([x3, x5], 1)), self.conv6_2_2(torch.cat([x3, x5], 1)),
self.conv6_2_3(torch.cat([x3, x5], 1)), self.conv6_2_4(torch.cat([x3, x5], 1))], 1)))
c2 = torch.cat([(self.cam_attention_2(c2_1)+self.row_attention_2(self.col_attention_2(c2_1))), self.Up1(c3)], 1)
c1_1 = (self.conv1_1(torch.cat(
[self.conv6_1_1(torch.cat([x1, x2], 1)), self.conv6_1_2(torch.cat([x1, x2], 1)),
self.conv6_1_3(torch.cat([x1, x2], 1)), self.conv6_1_4(torch.cat([x1, x2], 1))], 1)))
c1 = torch.cat([(self.cam_attention_1(c1_1)+self.row_attention_1(self.col_attention_1(c1_1))), self.Up1(c2)], 1)
out1 = self.conv6(c1)
return (out1, )