# FunSR/models/baselines/OverNet.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models import register

# Adds or subtracts a fixed per-channel RGB mean via a frozen 1x1 convolution.
class MeanShift(nn.Module):
def __init__(self, mean_rgb, sub):
super(MeanShift, self).__init__()
sign = -1 if sub else 1
r = mean_rgb[0] * sign
g = mean_rgb[1] * sign
b = mean_rgb[2] * sign
self.shifter = nn.Conv2d(3, 3, 1, 1, 0)
self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1)
self.shifter.bias.data = torch.Tensor([r, g, b])
# Freeze the mean shift layer
for params in self.shifter.parameters():
params.requires_grad = False
def forward(self, x):
x = self.shifter(x)
return x
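
# The frozen 1x1 convolution above is just an identity kernel plus a fixed
# per-channel bias, i.e. out = x - mean_rgb (sub=True) or x + mean_rgb.
# A minimal usage sketch (the values are the means from the commented-out
# sub_mean/add_mean lines further below):
#
#   x = torch.rand(1, 3, 8, 8)
#   y = MeanShift((0.4488, 0.4371, 0.4040), sub=True)(x)
#   # y[:, 0] == x[:, 0] - 0.4488, and similarly for the G and B channels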

# A single learnable scalar multiplier.
class Scale(nn.Module):
def __init__(self, init_value=1e-3):
super(Scale, self).__init__()
self.scale = nn.Parameter(torch.FloatTensor([init_value]))
    def forward(self, x):
        return x * self.scale
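
# Scale multiplies its input by one trainable scalar; OverNet uses pairs of
# these (res_scale / x_scale) as learnable mixing weights on residual and skip
# branches. A quick sketch:
#
#   s = Scale(init_value=0.5)
#   s(torch.ones(2, 2))   # a tensor of 0.5s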

# Squeeze-and-Excitation (SE) channel attention.
class SE(nn.Module):
def __init__(self, channel, reduction=16):
super(SE, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Sequential(
nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
nn.Sigmoid()
)
def forward(self, x):
y = self.avg_pool(x)
y = self.conv(y)
return x * y
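
# SE recap: global average pooling squeezes each channel to a scalar, the 1x1
# conv bottleneck (C -> C/16 -> C) plus sigmoid produces per-channel gates in
# (0, 1), and the input is rescaled channel-wise. Shape check (a sketch):
#
#   se = SE(64)
#   se(torch.randn(1, 64, 32, 32)).shape   # torch.Size([1, 64, 32, 32])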

# Wide-activation residual block (WDSR-style): 1x1 expand -> ReLU -> 1x1 linear
# compress -> 3x3 conv, followed by SE attention and learnable skip scales.
class ResidualBlock(nn.Module):
    def __init__(self, wn, in_channels, out_channels):
        super(ResidualBlock, self).__init__()
        self.res_scale = Scale(1)
        self.x_scale = Scale(1)
        self.SE = SE(out_channels, reduction=16)
        expand = 6    # channel expansion factor before the activation
        linear = 0.8  # linear compression factor after it
        body = [
            wn(nn.Conv2d(in_channels, in_channels * expand, 1, padding=0)),
            nn.ReLU(inplace=True),
            wn(nn.Conv2d(in_channels * expand, int(in_channels * linear), 1, padding=0)),
            wn(nn.Conv2d(int(in_channels * linear), out_channels, 3, padding=1)),
        ]
        self.body = nn.Sequential(*body)
def forward(self, x):
out = self.body(x)
out = self.SE(out)
out = self.res_scale(out) + self.x_scale(x)
return out
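
# The channel flow for a 64-channel input follows the wide-activation recipe:
# 64 -> 384 (1x1 expand, x6) -> 51 (1x1 linear, x0.8) -> 64 (3x3). A sketch,
# assuming weight_norm as the wrapper:
#
#   rb = ResidualBlock(nn.utils.weight_norm, 64, 64)
#   rb(torch.randn(1, 64, 16, 16)).shape   # torch.Size([1, 64, 16, 16])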

# Weight-normalised convolution + ReLU; used as the 1x1 reduction layer after
# each dense concatenation.
class BasicConv2d(nn.Module):
    def __init__(self, wn, in_planes, out_planes, kernel_size, stride, padding=0):
        super(BasicConv2d, self).__init__()
        self.conv = wn(nn.Conv2d(in_planes, out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True))
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.act(x)
        return x

class UpsampleBlock(nn.Module):
    def __init__(self, n_channels, upscale, wn, group=1):
        super(UpsampleBlock, self).__init__()
        self.up = _UpsampleBlock(n_channels, upscale=upscale, wn=wn, group=group)

    def forward(self, x, upscale):
        # `upscale` is accepted for interface compatibility; the factor was
        # fixed when `self.up` was built.
        return self.up(x)

class _UpsampleBlock(nn.Module):
    def __init__(self, n_channels, upscale, wn, group=1):
        super(_UpsampleBlock, self).__init__()
        modules = []
        if upscale in (2, 4, 8):
            # Chain log2(upscale) stages of (3x3 conv -> ReLU -> PixelShuffle(2)).
            # math.log2 is exact for powers of two and avoids the float pitfalls
            # of int(math.log(upscale, 2)).
            for _ in range(int(math.log2(upscale))):
                modules += [wn(nn.Conv2d(n_channels, 4 * n_channels, 3, 1, 1, groups=group)),
                            nn.ReLU(inplace=True)]
                modules += [nn.PixelShuffle(2)]
        elif upscale == 3:
            modules += [wn(nn.Conv2d(n_channels, 9 * n_channels, 3, 1, 1, groups=group)),
                        nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(3)]
        elif upscale == 5:
            modules += [wn(nn.Conv2d(n_channels, 25 * n_channels, 3, 1, 1, groups=group)),
                        nn.ReLU(inplace=True)]
            modules += [nn.PixelShuffle(5)]
        else:
            # Previously an unsupported factor silently produced an identity body.
            raise ValueError('unsupported upscale factor: {}'.format(upscale))
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        return self.body(x)
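
# PixelShuffle(r) rearranges (B, r*r*C, H, W) to (B, C, r*H, r*W), so each
# branch expands channels by r*r with a 3x3 conv before shuffling; x4 and x8
# chain two and three x2 stages. Sketch:
#
#   up = _UpsampleBlock(64, upscale=2, wn=nn.utils.weight_norm)
#   up(torch.randn(1, 64, 16, 16)).shape   # torch.Size([1, 64, 32, 32])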

# Local Dense Group (LDG): three residual blocks with dense connections, each
# concatenation squeezed back to the base width by a 1x1 reduction.
class LDGs(nn.Module):
def __init__(self,
in_channels, out_channels, wn,
group=1):
super(LDGs, self).__init__()
self.RB1 = ResidualBlock(wn, in_channels, out_channels)
self.RB2 = ResidualBlock(wn, in_channels, out_channels)
self.RB3 = ResidualBlock(wn, in_channels, out_channels)
self.reduction1 = BasicConv2d(wn, in_channels*2, out_channels, 1, 1, 0)
self.reduction2 = BasicConv2d(wn, in_channels*3, out_channels, 1, 1, 0)
self.reduction3 = BasicConv2d(wn, in_channels*4, out_channels, 1, 1, 0)
def forward(self, x):
c0 = o0 = x
RB1 = self.RB1(o0)
concat1 = torch.cat([c0, RB1], dim=1)
out1 = self.reduction1(concat1)
RB2 = self.RB2(out1)
concat2 = torch.cat([concat1, RB2], dim=1)
out2 = self.reduction2(concat2)
RB3 = self.RB3(out2)
concat3 = torch.cat([concat2, RB3], dim=1)
out3 = self.reduction3(concat3)
return out3
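
# Inside an LDG the input and every ResidualBlock output are concatenated, so
# the 1x1 reductions see 128, 192 and 256 channels and squeeze each back to 64.
# Sketch:
#
#   ldg = LDGs(64, 64, wn=nn.utils.weight_norm)
#   ldg(torch.randn(1, 64, 16, 16)).shape   # torch.Size([1, 64, 16, 16])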
@register('overnet')
class OverNet(nn.Module):
def __init__(self, upscale=5, group=4, *args, **kwargs):
super(OverNet, self).__init__()
        wn = torch.nn.utils.weight_norm  # weight-normalisation wrapper for every conv
self.upscale = upscale
# self.sub_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=True)
# self.add_mean = MeanShift((0.4488, 0.4371, 0.4040), sub=False)
self.entry_1 = wn(nn.Conv2d(3, 64, 3, 1, 1))
        # Three dense groups of residual blocks, densely connected at network level.
        self.GDG1 = LDGs(64, 64, wn=wn)
self.GDG2 = LDGs(64, 64, wn=wn)
self.GDG3 = LDGs(64, 64, wn=wn)
self.reduction1 = BasicConv2d(wn, 64*2, 64, 1, 1, 0)
self.reduction2 = BasicConv2d(wn, 64*3, 64, 1, 1, 0)
self.reduction3 = BasicConv2d(wn, 64*4, 64, 1, 1, 0)
self.reduction = BasicConv2d(wn, 64*3, 64, 1, 1, 0)
self.Global_skip = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(64, 64, 1, 1, 0), nn.ReLU(inplace=True))
self.upsample = UpsampleBlock(64, upscale=upscale, wn=wn, group=group)
self.exit1 = wn(nn.Conv2d(64, 3, 3, 1, 1))
self.res_scale = Scale(1)
self.x_scale = Scale(1)
    def forward(self, x, out_size):
        # `out_size` is the target (H, W): the network first upsamples by its
        # fixed integer factor, then bicubic-resizes both branches to `out_size`.
# x = self.sub_mean(x)
skip = x
x = self.entry_1(x)
c0 = o0 = x
GDG1 = self.GDG1(o0)
concat1 = torch.cat([c0, GDG1], dim=1)
out1 = self.reduction1(concat1)
GDG2 = self.GDG2(out1)
concat2 = torch.cat([concat1, GDG2], dim=1)
out2 = self.reduction2(concat2)
GDG3 = self.GDG3(out2)
concat3 = torch.cat([concat2, GDG3], dim=1)
out3 = self.reduction3(concat3)
        output = self.reduction(torch.cat((out1, out2, out3), dim=1))
output = self.res_scale(output) + self.x_scale(self.Global_skip(x))
output = self.upsample(output, upscale=self.upscale)
output = F.interpolate(output, out_size, mode='bicubic', align_corners=False)
skip = F.interpolate(skip, out_size, mode='bicubic', align_corners=False)
output = self.exit1(output) + skip
# output = self.add_mean(output)
return output
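

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module). OverNet
    # first upsamples by its fixed integer factor (x5 by default here), then
    # bicubic-resizes to the requested `out_size`, which is how FunSR drives
    # fixed-scale baselines at arbitrary output sizes.
    model = OverNet(upscale=5, group=4)
    lr = torch.randn(1, 3, 32, 32)
    sr = model(lr, out_size=(96, 96))
    print(sr.shape)   # torch.Size([1, 3, 96, 96])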