import torch
import torch.nn as nn
import torch.nn.functional as F


class DeepLabV3Decoder(nn.Sequential):
    """ASPP followed by a 3x3 convolution and ReLU.

    Only the last of the feature maps passed to ``forward`` is used.

    Args:
        in_channels: Channel count of the last feature map.
        out_channels: Channel count produced by the decoder; also stored as
            ``self.out_channels``.
        atrous_rates: Iterable of dilation rates, one ASPP branch per rate.
    """

    def __init__(self, in_channels, out_channels=256, atrous_rates=(12, 24, 36)):
        super().__init__(
            ASPP(in_channels, out_channels, atrous_rates),
            nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=True),
            nn.ReLU(),
        )
        self.out_channels = out_channels

    def forward(self, *features):
        """Run the decoder on ``features[-1]``; earlier features are ignored."""
        return super().forward(features[-1])


class ASPPConv(nn.Sequential):
    """3x3 dilated convolution followed by ReLU.

    ``padding`` equals ``dilation``, so the spatial size is preserved.
    """

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=True,
            ),
            nn.ReLU(),
        )


class ASPPSeparableConv(nn.Sequential):
    """Depthwise-separable counterpart of :class:`ASPPConv`, followed by ReLU."""

    def __init__(self, in_channels, out_channels, dilation):
        super().__init__(
            SeparableConv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding=dilation,
                dilation=dilation,
                bias=True,
            ),
            nn.ReLU(),
        )


class ASPPPooling(nn.Sequential):
    """Global-average-pool branch: pool to 1x1, 1x1 conv, ReLU, resize back."""

    def __init__(self, in_channels, out_channels):
        super().__init__(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the pooled features bilinearly resized to ``x``'s last two dims."""
        spatial_size = x.shape[-2:]
        # nn.Sequential.forward applies the child modules in registration order.
        pooled = super().forward(x)
        return F.interpolate(
            pooled, size=spatial_size, mode="bilinear", align_corners=False
        )


class ASPP(nn.Module):
    """Parallel 1x1, dilated and pooled branches, concatenated and projected.

    Branch order in ``self.convs`` is: one 1x1 conv branch, one dilated branch
    per entry of ``atrous_rates`` (in the given order), then the pooling
    branch. Any number of rates is accepted; the projection layer is sized
    from the actual number of branches.

    Args:
        in_channels: Channel count of the input.
        out_channels: Channel count of every branch and of the output.
        atrous_rates: Iterable of dilation rates.
        separable: Use :class:`ASPPSeparableConv` instead of :class:`ASPPConv`
            for the dilated branches.
    """

    def __init__(self, in_channels, out_channels, atrous_rates, separable=False):
        super().__init__()
        # Materialise once so generators and other one-shot iterables work.
        rates = tuple(atrous_rates)
        branch_cls = ASPPSeparableConv if separable else ASPPConv

        # Construction order is kept (1x1, dilated..., pooling, projection) so
        # module indices and seeded parameter initialisation stay stable.
        branches = [
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, bias=True),
                nn.ReLU(),
            )
        ]
        branches.extend(
            branch_cls(in_channels, out_channels, rate) for rate in rates
        )
        branches.append(ASPPPooling(in_channels, out_channels))
        self.convs = nn.ModuleList(branches)

        self.project = nn.Sequential(
            nn.Conv2d(
                len(self.convs) * out_channels,
                out_channels,
                kernel_size=1,
                bias=True,
            ),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply every branch to ``x``, concatenate along dim 1 and project.

        Concatenation uses ``dim=1``, so ``x`` needs a leading batch dimension
        for dim 1 to be the channel axis.
        """
        stacked = torch.cat([branch(x) for branch in self.convs], dim=1)
        return self.project(stacked)


class SeparableConv2d(nn.Sequential):
    """Depthwise convolution (no bias) followed by a 1x1 pointwise convolution.

    ``kernel_size``, ``stride``, ``padding`` and ``dilation`` apply to the
    depthwise convolution only; ``bias`` applies to the pointwise convolution
    only.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        bias=True,
    ):
        depthwise_conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,  # one filter group per input channel
            bias=False,
        )
        pointwise_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=bias,
        )
        super().__init__(depthwise_conv, pointwise_conv)