FunSR / models /baselines /arbrcan.py
KyanChen's picture
add
02c5426
raw
history blame
12 kB
import torch.nn as nn
import torch
import numpy as np
import torch.nn.functional as F
import math
import models
from models import register
def default_conv(in_channels, out_channels, kernel_size, bias=True):
    """Build a 2-D convolution padded by half of its kernel size.

    With ``nn.Conv2d``'s default stride of 1 and an odd ``kernel_size``,
    the returned layer keeps the spatial size of its input.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        kernel_size: Side length of the square kernel.
        bias: Whether the convolution carries a learnable bias.

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    half_kernel = kernel_size // 2
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        padding=half_kernel,
        bias=bias,
    )
## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: rescale every channel by a learned gate.

    The gate is computed from the globally average-pooled feature map by a
    1x1 bottleneck (``channel -> channel // reduction -> channel``) that
    ends in a sigmoid.

    Args:
        channel: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``channel`` inside the bottleneck.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction
        # Collapse each channel to a single value (H x W -> 1 x 1).
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Squeeze, then expand back and map to per-channel weights.
        self.conv_du = nn.Sequential(
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention gate."""
        gate = self.conv_du(self.avg_pool(x))
        return x * gate
## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel attention block.

    The body is ``conv -> [BN] -> act -> conv -> [BN] -> CALayer`` and its
    output is added to the block input.

    Args:
        conv: Factory called as ``conv(n_feat, n_feat, kernel_size, bias=bias)``.
        n_feat: Number of feature channels.
        kernel_size: Kernel size handed to ``conv``.
        reduction: Channel reduction ratio of the attention layer.
        bias: Whether the convolutions use a bias.
        bn: Insert a ``BatchNorm2d`` after each convolution when true.
        act: Activation module placed after the first convolution.
        res_scale: Stored on the module as ``self.res_scale``; ``forward``
            does not apply it.
    """

    def __init__(
            self, conv, n_feat, kernel_size, reduction,
            bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
        super().__init__()
        layers = []
        for stage in (0, 1):
            layers.append(conv(n_feat, n_feat, kernel_size, bias=bias))
            if bn:
                layers.append(nn.BatchNorm2d(n_feat))
            # Only the first convolution is followed by the activation.
            if stage == 0:
                layers.append(act)
        layers.append(CALayer(n_feat, reduction))
        self.body = nn.Sequential(*layers)
        self.res_scale = res_scale

    def forward(self, x):
        """Return ``body(x) + x``."""
        out = self.body(x)
        out += x
        return out
## Residual Group (RG)
class ResidualGroup(nn.Module):
    """Residual group: a stack of RCABs plus a closing conv, with a skip.

    Args:
        conv: Convolution factory forwarded to every ``RCAB`` and used for
            the closing convolution.
        n_feat: Number of feature channels.
        kernel_size: Kernel size of all convolutions.
        reduction: Channel reduction ratio of the attention layers.
        act: Activation module handed to every ``RCAB``. The same instance
            is shared by all blocks of the group.
        res_scale: Residual scale handed to every ``RCAB``.
        n_resblocks: Number of ``RCAB`` blocks in the group.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        # Forward the caller's `act` / `res_scale` instead of silently
        # replacing them with a hard-coded ReLU and 1.
        modules_body = [
            RCAB(
                conv, n_feat, kernel_size, reduction,
                bias=True, bn=False, act=act, res_scale=res_scale)
            for _ in range(n_resblocks)
        ]
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """Return ``body(x) + x``."""
        res = self.body(x)
        res += x
        return res
class SA_upsample(nn.Module):
    """Scale-aware upsampling layer driven by a mixture of expert filters.

    For every output pixel a small network looks at the two inverse scale
    factors and the pixel's relative sub-pixel position, and predicts
    (a) a 2-D sampling offset and (b) routing weights that blend
    ``num_experts`` pairs of 1x1 compress/expand kernels. The input is
    resampled with ``grid_sample`` and filtered with the blended kernels;
    the filtered result is added to the resampled features.

    Args:
        channels: Number of channels of the incoming feature map.
        num_experts: Number of expert kernel pairs.
        bias: Stored as ``self.bias``; not otherwise used by this class.
    """

    def __init__(self, channels, num_experts=4, bias=False):
        super().__init__()
        self.bias = bias
        self.num_experts = num_experts
        self.channels = channels

        squeezed = channels // 8
        # experts: the compress bank is created (and initialised) first
        self.weight_compress = self._expert_bank(num_experts, squeezed, channels)
        self.weight_expand = self._expert_bank(num_experts, channels, squeezed)

        # two FC layers (as 1x1 convolutions) embedding the 4 input features
        self.body = nn.Sequential(
            nn.Conv2d(4, 64, 1, 1, 0, bias=True),
            nn.ReLU(True),
            nn.Conv2d(64, 64, 1, 1, 0, bias=True),
            nn.ReLU(True),
        )
        # routing head: one sigmoid gate per expert
        self.routing = nn.Sequential(
            nn.Conv2d(64, num_experts, 1, 1, 0, bias=True),
            nn.Sigmoid(),
        )
        # offset head: (x, y) sampling offset
        self.offset = nn.Conv2d(64, 2, 1, 1, 0, bias=True)

    @staticmethod
    def _expert_bank(num_experts, out_ch, in_ch):
        """Create ``num_experts`` Kaiming-initialised 1x1 kernels as one parameter."""
        kernels = []
        for _ in range(num_experts):
            kernel = torch.empty(out_ch, in_ch, 1, 1)
            nn.init.kaiming_uniform_(kernel, a=math.sqrt(5))
            kernels.append(kernel)
        return nn.Parameter(torch.stack(kernels, 0))

    def forward(self, x, scale, scale2):
        """Upsample ``x`` by ``scale`` vertically and ``scale2`` horizontally.

        Args:
            x: Feature map of shape ``(b, c, h, w)``.
            scale: Vertical scale factor.
            scale2: Horizontal scale factor.

        Returns:
            Tensor of shape ``(b, c, round(h * scale), round(w * scale2))``.
        """
        b, _, h, w = x.size()
        out_h = round(h * scale)
        out_w = round(w * scale2)
        n_exp = self.num_experts
        squeezed = self.channels // 8

        # (1) relative position of every HR pixel centre inside its LR pixel
        idx_h = torch.arange(out_h, dtype=torch.float32, device=x.device).unsqueeze(0)
        idx_w = torch.arange(out_w, dtype=torch.float32, device=x.device).unsqueeze(0)
        pos_h = (idx_h + 0.5) / scale
        pos_w = (idx_w + 0.5) / scale2
        rel_h = (pos_h - torch.floor(pos_h + 1e-3) - 0.5).permute(1, 0)  # (out_h, 1)
        rel_w = pos_w - torch.floor(pos_w + 1e-3) - 0.5                  # (1, out_w)

        ones_plane = torch.ones_like(rel_h).expand([-1, out_w]).unsqueeze(0)
        coord_feats = torch.cat(
            (
                ones_plane / scale2,
                ones_plane / scale,
                rel_h.expand([-1, out_w]).unsqueeze(0),
                rel_w.expand([out_h, -1]).unsqueeze(0),
            ),
            0,
        ).unsqueeze(0)  # (1, 4, out_h, out_w)

        # (2) predict offsets and per-pixel filters
        embedding = self.body(coord_feats)
        offset = self.offset(embedding)
        gates = self.routing(embedding)
        gates = gates.view(n_exp, out_h * out_w).transpose(0, 1)  # (out_h*out_w, n_exp)

        compress = torch.matmul(gates, self.weight_compress.view(n_exp, -1))
        compress = compress.view(1, out_h, out_w, squeezed, self.channels)
        expand = torch.matmul(gates, self.weight_expand.view(n_exp, -1))
        expand = expand.view(1, out_h, out_w, self.channels, squeezed)

        # (3) grid sample & spatially varying filtering
        sampled = grid_sample(x, offset, scale, scale2)          # b * c * out_h * out_w
        columns = sampled.unsqueeze(-1).permute(0, 2, 3, 1, 4)   # b * out_h * out_w * c * 1

        out = torch.matmul(compress.expand([b, -1, -1, -1, -1]), columns)
        out = torch.matmul(expand.expand([b, -1, -1, -1, -1]), out).squeeze(-1)
        return out.permute(0, 3, 1, 2) + sampled
class SA_adapt(nn.Module):
    """Scale-aware feature adaption block.

    A small hourglass network predicts a single-channel mask in ``[0, 1]``
    from the features; a scale-aware convolution (``SA_conv``) produces
    adapted features. The block returns ``x + adapted * mask``.

    Args:
        channels: Number of channels of the incoming feature map.
    """

    def __init__(self, channels):
        super(SA_adapt, self).__init__()
        # Hourglass mask predictor: the features are halved by the average
        # pooling and brought back by the bilinear x2 upsampling.
        self.mask = nn.Sequential(
            nn.Conv2d(channels, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.AvgPool2d(2),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Conv2d(16, 16, 3, 1, 1),
            nn.BatchNorm2d(16),
            nn.ReLU(True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(16, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.adapt = SA_conv(channels, channels, 3, 1, 1)

    def forward(self, x, scale, scale2):
        """Return ``x`` plus its mask-gated, scale-adapted features.

        Args:
            x: Feature map of shape ``(b, channels, h, w)``.
            scale: Vertical scale factor passed to the scale-aware conv.
            scale2: Horizontal scale factor passed to the scale-aware conv.
        """
        mask = self.mask(x)
        # AvgPool2d(2) floors odd sizes, so after the x2 upsampling the mask
        # is one pixel short on every odd dimension. Resize it so that the
        # element-wise product below is well defined for any input size.
        if mask.shape[-2:] != x.shape[-2:]:
            mask = F.interpolate(
                mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
        adapted = self.adapt(x, scale, scale2)
        return x + adapted * mask
class SA_conv(nn.Module):
    """Scale-aware convolution built from a mixture of expert kernels.

    ``num_experts`` full convolution kernels (and optionally biases) are
    kept as parameters. At call time a two-layer MLP turns the inverse
    scale factors ``(1/scale, 1/scale2)`` into softmax routing weights, the
    experts are blended into one kernel, and a regular ``F.conv2d`` is run.

    Args:
        channels_in: Number of input channels.
        channels_out: Number of output channels.
        kernel_size: Side length of the square expert kernels.
        stride: Stride of the convolution.
        padding: Zero padding of the convolution.
        bias: Whether to learn (and blend) per-expert biases.
        num_experts: Number of expert kernels.
    """

    def __init__(self, channels_in, channels_out, kernel_size=3, stride=1, padding=1, bias=False, num_experts=4):
        super(SA_conv, self).__init__()
        self.channels_out = channels_out
        self.channels_in = channels_in
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.num_experts = num_experts
        self.bias = bias

        # FC layers to generate routing weights
        self.routing = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(True),
            nn.Linear(64, num_experts),
            nn.Softmax(1)
        )

        # initialize experts
        weight_pool = []
        for _ in range(num_experts):
            expert = torch.empty(channels_out, channels_in, kernel_size, kernel_size)
            nn.init.kaiming_uniform_(expert, a=math.sqrt(5))
            weight_pool.append(expert)
        self.weight_pool = nn.Parameter(torch.stack(weight_pool, 0))

        if bias:
            self.bias_pool = nn.Parameter(torch.empty(num_experts, channels_out))
            # Fan-in of ONE expert kernel. Deriving it from the stacked 5-D
            # pool would wrongly multiply it by channels_out.
            fan_in = channels_in * kernel_size * kernel_size
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_pool, -bound, bound)

    def forward(self, x, scale, scale2):
        """Convolve ``x`` with the expert kernel blended for the given scales.

        Args:
            x: Input of shape ``(b, channels_in, h, w)``.
            scale: Vertical scale factor.
            scale2: Horizontal scale factor.

        Returns:
            The result of ``F.conv2d`` with the fused kernel (and fused bias
            when the layer was built with ``bias=True``).
        """
        # generate routing weights, shape (1, num_experts)
        inv_scale = x.new_ones(1, 1) / scale
        inv_scale2 = x.new_ones(1, 1) / scale2
        routing_weights = self.routing(torch.cat((inv_scale, inv_scale2), 1))

        # fuse experts
        per_expert = routing_weights.view(self.num_experts, 1, 1)
        fused_weight = (self.weight_pool.view(self.num_experts, -1, 1) * per_expert).sum(0)
        fused_weight = fused_weight.view(-1, self.channels_in, self.kernel_size, self.kernel_size)

        if self.bias:
            # torch.mm needs 2-D operands: (1, E) @ (E, channels_out)
            fused_bias = torch.mm(routing_weights, self.bias_pool).view(-1)
        else:
            fused_bias = None

        # convolution
        return F.conv2d(x, fused_weight, fused_bias, stride=self.stride, padding=self.padding)
def grid_sample(x, offset, scale, scale2):
    """Resample ``x`` on a scaled grid displaced by per-pixel offsets.

    Every pixel centre of the output (size ``round(scale * h)`` by
    ``round(scale2 * w)``) is projected back into the pixel coordinates of
    ``x``, shifted by ``offset`` (expressed in input pixels), and sampled
    bilinearly with zero padding outside the image.

    Args:
        x: Input of shape ``(b, c, h, w)``.
        offset: Offsets of shape ``(b or 1, 2, out_h, out_w)``; channel 0 is
            the horizontal and channel 1 the vertical displacement.
        scale: Vertical scale factor.
        scale2: Horizontal scale factor.

    Returns:
        Tensor of shape ``(b, c, out_h, out_w)``.
    """
    b, _, h, w = x.size()
    out_h = round(scale * h)
    out_w = round(scale2 * w)

    # Output pixel centres projected into input pixel coordinates, built
    # directly on the target device.
    cols = (torch.arange(out_w, dtype=torch.float32, device=x.device) + 0.5) / scale2 - 0.5
    rows = (torch.arange(out_h, dtype=torch.float32, device=x.device) + 0.5) / scale - 0.5

    # Normalise to [-1, 1] so that pixel index i maps to 2*i/(size-1) - 1.
    # A one-pixel axis has no extent; use 1 to avoid dividing by zero.
    span_w = max(w - 1, 1)
    span_h = max(h - 1, 1)
    cols = cols * 2 / span_w - 1
    rows = rows * 2 / span_h - 1

    grid_x = cols.view(1, 1, out_w).expand(b, out_h, out_w)
    grid_y = rows.view(1, out_h, 1).expand(b, out_h, out_w)

    # add offsets, converted from input pixels to normalised units
    grid_x = grid_x + offset[:, 0, :, :] * 2 / span_w
    grid_y = grid_y + offset[:, 1, :, :] * 2 / span_h
    grid = torch.stack((grid_x, grid_y), dim=-1).to(x.dtype)  # b * out_h * out_w * 2

    # The normalisation above places -1 / +1 on the centres of the corner
    # pixels, i.e. the `align_corners=True` convention. It must be requested
    # explicitly: the library default is `False`, which would shift every
    # sample and emit a warning.
    return F.grid_sample(
        x, grid, mode='bilinear', padding_mode='zeros', align_corners=True)
@register('arbrcan')
class ArbRCAN(nn.Module):
    """RCAN backbone with scale-aware plug-in modules for arbitrary-scale SR.

    The network is ``head conv -> 10 residual groups (each followed by a
    scale-aware adaption block) -> conv -> global skip -> scale-aware
    upsampling -> tail conv``. All hyper-parameters are fixed in the
    constructor.

    Args:
        encoder_spec: Accepted but not used by this class.
        conv: Convolution factory used for every plain convolution.
    """

    def __init__(self, encoder_spec=None, conv=default_conv):
        super().__init__()

        # fixed RCAN configuration
        n_resgroups, n_resblocks = 10, 20
        n_feats, kernel_size, reduction = 64, 3, 16
        act = nn.ReLU(True)
        n_colors, res_scale = 3, 1
        self.n_resgroups = n_resgroups

        # head module
        self.head = nn.Sequential(conv(n_colors, n_feats, kernel_size))

        # body module: residual groups followed by one convolution
        trunk = []
        for _ in range(n_resgroups):
            trunk.append(
                ResidualGroup(
                    conv, n_feats, kernel_size, reduction,
                    act=act, res_scale=res_scale, n_resblocks=n_resblocks,
                )
            )
        trunk.append(conv(n_feats, n_feats, kernel_size))
        self.body = nn.Sequential(*trunk)

        # tail module; slot 0 is a placeholder to match pre-trained RCAN model
        self.tail = nn.Sequential(None, conv(n_feats, n_colors, kernel_size))

        ########## plug-in modules ##########
        # scale-aware feature adaption block
        # For RCAN, feature adaption is performed after each backbone block, i.e., K=1
        self.K = 1
        adapters = []
        for _ in range(self.n_resgroups // self.K):
            adapters.append(SA_adapt(64))
        self.sa_adapt = nn.Sequential(*adapters)

        # scale-aware upsampling layer
        self.sa_upsample = SA_upsample(64)

    def set_scale(self, scale, scale2):
        """Store the two scale factors on the module.

        ``forward`` derives its scales from ``size`` and does not read these
        attributes.
        """
        self.scale, self.scale2 = scale, scale2

    def forward(self, x, size):
        """Super-resolve ``x`` to the spatial size ``size``.

        Args:
            x: Input of shape ``(B, C, H, W)``.
            size: Target ``(H_up, W_up)``; the vertical and horizontal scale
                factors are ``H_up / H`` and ``W_up / W``.

        Returns:
            The output of the tail convolution applied to the upsampled
            features.
        """
        _, _, H, W = x.shape
        H_up, W_up = size
        scale, scale2 = H_up / H, W_up / W

        # head
        x = self.head(x)

        # body
        res = x
        for i, group in zip(range(self.n_resgroups), self.body):
            res = group(res)
            # scale-aware feature adaption
            if (i + 1) % self.K == 0:
                res = self.sa_adapt[i](res, scale, scale2)
        res = self.body[-1](res)
        res += x

        # scale-aware upsampling
        res = self.sa_upsample(res, scale, scale2)

        # tail
        return self.tail[1](res)