|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.nn.utils.weight_norm as wn |
|
|
|
|
|
|
|
|
|
|
|
def make_coord(shape, ranges=None, flatten=True):
    """Make coordinates at grid centers.

    For each dimension ``i`` of ``shape``, places ``shape[i]`` evenly spaced
    points at the centers of equal-width cells spanning ``ranges[i]``
    (default (-1, 1)). Returns the meshgrid stacked along the last axis,
    optionally flattened to shape (prod(shape), len(shape)).
    """
    per_dim = []
    for axis, n in enumerate(shape):
        lo, hi = (-1, 1) if ranges is None else ranges[axis]
        # Half the width of one grid cell; centers sit half a cell in
        # from each boundary.
        half_cell = (hi - lo) / (2 * n)
        centers = lo + half_cell + (2 * half_cell) * torch.arange(n).float()
        per_dim.append(centers)
    grid = torch.stack(torch.meshgrid(*per_dim), dim=-1)
    return grid.view(-1, grid.shape[-1]) if flatten else grid
|
|
|
|
|
class UPLayer_MS_V9(nn.Module):
    """Scale-aware multi-level upsampling head.

    Builds a pyramid of x2 pixel-shuffle-upsampled features, resizes every
    level to the requested output size, weights each level by a learned
    function of the scale ratio, and fuses the concatenation through a stack
    of 1x1 convolutions.
    """

    def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
        """
        Args:
            n_feats: channel count of the input feature map.
            kSize: conv kernel size used inside each x2 upsampler.
            out_channels: channel count of the fused output.
            interpolate_mode: mode passed to ``F.interpolate`` (e.g. "bilinear").
            levels: number of pyramid levels (input plus levels-1 x2 stages).
        """
        super().__init__()
        self.interpolate_mode = interpolate_mode
        self.levels = levels

        # One x2 (conv -> PixelShuffle) upsampler per extra pyramid level.
        self.UPNet_x2_list = []
        for _ in range(levels - 1):
            self.UPNet_x2_list.append(
                nn.Sequential(
                    nn.Conv2d(
                        n_feats,
                        n_feats * 4,
                        kSize,
                        padding=(kSize - 1) // 2,
                        stride=1,
                    ),
                    nn.PixelShuffle(2),
                )
            )

        # Maps the scale ratio (in_h / out_h) to one weight in (0, 1) per level.
        self.scale_aware_layer = nn.Sequential(
            nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, levels), nn.Sigmoid()
        )

        # Wrap in nn.Sequential so the stages are registered as sub-modules.
        self.UPNet_x2_list = nn.Sequential(*self.UPNet_x2_list)

        # 1x1-conv MLP that fuses the concatenated, scale-weighted levels.
        self.fuse = nn.Sequential(
            nn.Conv2d(n_feats * levels, 256, kernel_size=1, padding=0, stride=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
            nn.ReLU(),
            nn.Conv2d(256, out_channels, kernel_size=1, padding=0, stride=1),
        )

    def forward(self, x, out_size):
        """Upsample ``x`` (B, n_feats, H, W) to (B, out_channels, *out_size).

        ``x`` may also be a list of per-level feature maps, in which case the
        call is dispatched to :meth:`forward_list`.
        """
        if isinstance(out_size, int):
            out_size = [out_size, out_size]

        if isinstance(x, list):
            return self.forward_list(x, out_size)

        # Bug fix: derive the device from the input instead of hard-coding
        # "cuda", so the module also runs on CPU (matches forward_list).
        r = torch.tensor([x.shape[2] / out_size[0]], device=x.device)
        scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]

        # Build the feature pyramid: level l is level l-1 upsampled x2.
        x_list = [x]
        for l in range(1, self.levels):
            x_list.append(self.UPNet_x2_list[l - 1](x_list[l - 1]))

        # Resize every level to the target size and apply its scale weight.
        x_resize_list = []
        for l in range(self.levels):
            x_resize = F.interpolate(
                x_list[l], out_size, mode=self.interpolate_mode, align_corners=False
            )
            x_resize = x_resize * scale_w[l]
            x_resize_list.append(x_resize)

        return self.fuse(torch.cat(tuple(x_resize_list), 1))

    def forward_list(self, h_list, out_size):
        """Fuse externally provided per-level features.

        ``h_list[l]`` is passed through the first ``l`` shared x2 upsamplers
        before being resized, scale-weighted, and fused.
        """
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"
        device = h_list[0].device
        r = torch.tensor([h_list[0].shape[2] / out_size[0]], device=device)
        scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]

        x_resize_list = []
        for l in range(self.levels):
            h = h_list[l]
            for i in range(l):
                h = self.UPNet_x2_list[i](h)
            x_resize = F.interpolate(
                h, out_size, mode=self.interpolate_mode, align_corners=False
            )
            x_resize = x_resize * scale_w[l]
            x_resize_list.append(x_resize)

        return self.fuse(torch.cat(tuple(x_resize_list), 1))
|
|
|
|
|
class UPLayer_MS_WN(nn.Module):
    """Weight-normalized scale-aware multi-level upsampling head.

    Same topology as ``UPLayer_MS_V9`` but every conv/linear is wrapped in
    weight normalization, and the per-level resize is pluggable: standard
    ``F.interpolate`` modes or a learned "MLP" interpolator.
    """

    def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
        """
        Args:
            n_feats: channel count of the input feature map.
            kSize: conv kernel size used inside each x2 upsampler.
            out_channels: channel count of the fused output.
            interpolate_mode: one of "bilinear", "bicubic", "nearest", "MLP".
            levels: number of pyramid levels (input plus levels-1 x2 stages).
        """
        super().__init__()
        self.interpolate_mode = interpolate_mode
        self.levels = levels

        # One weight-normalized x2 (conv -> PixelShuffle) upsampler per extra
        # pyramid level.
        self.UPNet_x2_list = []
        for _ in range(levels - 1):
            self.UPNet_x2_list.append(
                nn.Sequential(
                    wn(
                        nn.Conv2d(
                            n_feats,
                            n_feats * 4,
                            kSize,
                            padding=(kSize - 1) // 2,
                            stride=1,
                        )
                    ),
                    nn.PixelShuffle(2),
                )
            )

        # Maps the scale ratio (in_h / out_h) to one weight in (0, 1) per level.
        self.scale_aware_layer = nn.Sequential(
            wn(nn.Linear(1, 64)), nn.ReLU(), wn(nn.Linear(64, levels)), nn.Sigmoid()
        )

        # Wrap in nn.Sequential so the stages are registered as sub-modules.
        self.UPNet_x2_list = nn.Sequential(*self.UPNet_x2_list)

        # 1x1-conv MLP that fuses the concatenated, scale-weighted levels.
        self.fuse = nn.Sequential(
            wn(
                nn.Conv2d(n_feats * levels, 256, kernel_size=1, padding=0, stride=1)
            ),
            nn.ReLU(),
            wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
            nn.ReLU(),
            wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
            nn.ReLU(),
            wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
            nn.ReLU(),
            wn(nn.Conv2d(256, out_channels, kernel_size=1, padding=0, stride=1)),
        )

        assert self.interpolate_mode in (
            "bilinear",
            "bicubic",
            "nearest",
            "MLP",
        ), "Interpolate mode must be bilinear/bicubic/nearest/MLP"
        if self.interpolate_mode == "MLP":
            self.feature_interpolater = MLP_Interpolate(n_feats, radius=3)
        elif self.interpolate_mode == "nearest":
            # "nearest" does not accept align_corners, so use a separate lambda.
            self.feature_interpolater = lambda x, out_size: F.interpolate(
                x, out_size, mode=self.interpolate_mode
            )
        else:
            self.feature_interpolater = lambda x, out_size: F.interpolate(
                x, out_size, mode=self.interpolate_mode, align_corners=False
            )

    def forward(self, x, out_size):
        """Upsample ``x`` (B, n_feats, H, W) to (B, out_channels, *out_size).

        ``x`` may also be a list of per-level feature maps, in which case the
        call is dispatched to :meth:`forward_list`.
        """
        if isinstance(out_size, int):
            out_size = [out_size, out_size]

        if isinstance(x, list):
            return self.forward_list(x, out_size)

        # Bug fix: derive the device from the input instead of hard-coding
        # "cuda", so the module also runs on CPU (matches forward_list).
        r = torch.tensor([x.shape[2] / out_size[0]], device=x.device)
        scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]

        # Build the feature pyramid: level l is level l-1 upsampled x2.
        x_list = [x]
        for l in range(1, self.levels):
            x_list.append(self.UPNet_x2_list[l - 1](x_list[l - 1]))

        # Resize every level to the target size and apply its scale weight.
        x_resize_list = []
        for l in range(self.levels):
            x_resize = self.feature_interpolater(x_list[l], out_size)
            x_resize = x_resize * scale_w[l]
            x_resize_list.append(x_resize)

        return self.fuse(torch.cat(tuple(x_resize_list), 1))

    def forward_list(self, h_list, out_size):
        """Fuse externally provided per-level features.

        ``h_list[l]`` is passed through the first ``l`` shared x2 upsamplers
        before being resized, scale-weighted, and fused.
        """
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"
        device = h_list[0].device
        r = torch.tensor([h_list[0].shape[2] / out_size[0]], device=device)
        scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]

        x_resize_list = []
        for l in range(self.levels):
            h = h_list[l]
            for i in range(l):
                h = self.UPNet_x2_list[i](h)
            x_resize = self.feature_interpolater(h, out_size)
            x_resize = x_resize * scale_w[l]
            x_resize_list.append(x_resize)

        return self.fuse(torch.cat(tuple(x_resize_list), 1))
|
|
|
|
|
class UPLayer_MS_WN_woSA(UPLayer_MS_WN):
    """``UPLayer_MS_WN`` without the scale-aware per-level weighting.

    Identical pyramid construction, interpolation, and fusion, but every
    level contributes with weight 1 instead of a learned, scale-dependent
    weight.
    """

    def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
        super().__init__(n_feats, kSize, out_channels, interpolate_mode, levels)

    def forward(self, x, out_size):
        """Upsample ``x`` (tensor, or list of per-level maps) to ``out_size``."""
        if type(out_size) == int:
            out_size = [out_size, out_size]

        if type(x) == list:
            return self.forward_list(x, out_size)

        # Grow the pyramid: each new level is the previous one upsampled x2.
        pyramid = [x]
        while len(pyramid) < self.levels:
            pyramid.append(self.UPNet_x2_list[len(pyramid) - 1](pyramid[-1]))

        resized = [self.feature_interpolater(feat, out_size) for feat in pyramid]
        return self.fuse(torch.cat(tuple(resized), 1))

    def forward_list(self, h_list, out_size):
        """Fuse per-level features; level l is first upsampled x2 l times."""
        assert (
            len(h_list) == self.levels
        ), "The Length of input list must equal to the number of levels"

        resized = []
        for level, feat in enumerate(h_list):
            for stage_idx in range(level):
                feat = self.UPNet_x2_list[stage_idx](feat)
            resized.append(self.feature_interpolater(feat, out_size))

        return self.fuse(torch.cat(tuple(resized), 1))
|
|
|
|
|
class OSM(nn.Module):
    """Over-scale module.

    Projects the features to an RGB image at ``overscale`` times the input
    resolution via pixel shuffle, then bicubically resamples to the
    requested output size.
    """

    def __init__(self, n_feats, overscale):
        """
        Args:
            n_feats: channel count of the input feature map.
            overscale: pixel-shuffle upscale factor applied before resampling.
        """
        super().__init__()
        # Generalization: the intermediate channel count must be
        # 64 * overscale**2 so PixelShuffle(overscale) leaves exactly the 64
        # channels the final conv expects. The original hard-coded 1600,
        # which is only correct for overscale == 5 (64 * 5**2 == 1600), so
        # this is backward-compatible for that case.
        self.body = nn.Sequential(
            wn(nn.Conv2d(n_feats, 64 * overscale ** 2, 3, padding=1)),
            nn.PixelShuffle(overscale),
            wn(nn.Conv2d(64, 3, 3, padding=1)),
        )

    def forward(self, x, out_size):
        """Return a (B, 3, *out_size) image computed from features ``x``."""
        h = self.body(x)
        return F.interpolate(h, out_size, mode="bicubic", align_corners=False)
|
|
|
|
|
class MLP_Interpolate(nn.Module):
    """Learned feature resampler (LIIF-style local implicit interpolation).

    For each output-grid coordinate, gathers the nearest input pixel's
    radius x radius unfolded neighborhood together with the relative offset
    to that pixel, and maps them through a small MLP to produce the output
    feature vector at that location.
    """

    def __init__(self, n_feat, radius=2):
        """
        Args:
            n_feat: channel count of the features to resample.
            radius: side length of the unfolded local neighborhood.
        """
        super().__init__()
        self.radius = radius

        # Input: unfolded neighborhood (n_feat * radius^2) + 2-D relative coord.
        self.f_transfer = nn.Sequential(
            nn.Linear(n_feat * self.radius * self.radius + 2, n_feat),
            nn.ReLU(True),
            nn.Linear(n_feat, n_feat),
        )

    def forward(self, x, out_size):
        """Resample ``x`` (B, C, H, W) to (B, C, *out_size)."""
        device = x.device

        # Replace each spatial location with its radius x radius neighborhood.
        x_unfold = F.unfold(x, self.radius, padding=self.radius // 2)
        x_unfold = x_unfold.view(
            x.shape[0], x.shape[1] * (self.radius ** 2), x.shape[2], x.shape[3]
        )

        in_shape = x.shape[-2:]
        # Bug fix: build coordinate grids on the input's device rather than
        # hard-coding .cuda(), so the module also runs on CPU.
        in_coord = (
            make_coord(in_shape, flatten=False)
            .to(device)
            .permute(2, 0, 1)
            .unsqueeze(0)
            .expand(x.shape[0], 2, *in_shape)
        )

        if isinstance(out_size, int):
            out_size = [out_size, out_size]

        out_coord = make_coord(out_size, flatten=True).to(device)
        out_coord = out_coord.expand(x.shape[0], *out_coord.shape)

        # Nearest-neighbor gather of features and of source coordinates at
        # each query point (grid_sample expects (x, y) order, hence flip(-1)).
        q_feat = F.grid_sample(
            x_unfold,
            out_coord.flip(-1).unsqueeze(1),
            mode="nearest",
            align_corners=False,
        )[:, :, 0, :].permute(0, 2, 1)
        q_coord = F.grid_sample(
            in_coord,
            out_coord.flip(-1).unsqueeze(1),
            mode="nearest",
            align_corners=False,
        )[:, :, 0, :].permute(0, 2, 1)

        # Offset from the query point to its source pixel, rescaled so one
        # input cell spans roughly 2 units in each axis.
        rel_coord = out_coord - q_coord
        rel_coord[:, :, 0] *= x.shape[-2]
        rel_coord[:, :, 1] *= x.shape[-1]

        inp = torch.cat([q_feat, rel_coord], dim=-1)

        # Run the MLP per query point, then fold back to an image layout.
        bs, q = out_coord.shape[:2]
        pred = self.f_transfer(inp.view(bs * q, -1)).view(bs, q, -1)
        pred = (
            pred.view(x.shape[0], *out_size, x.shape[1])
            .permute(0, 3, 1, 2)
            .contiguous()
        )

        return pred
|
|
|
|
|
class LIIF_Upsampler(nn.Module):
    """Placeholder for a LIIF-based upsampler; construction always raises."""

    def __init__(self):
        super().__init__()
        # Not implemented yet: fail loudly on any attempt to instantiate.
        raise NotImplementedError

    def forward(self):
        """No-op stub; unreachable because __init__ always raises."""
        pass
|
|