FunSR / models /baselines /upsampler.py
KyanChen's picture
add
02c5426
raw
history blame
11.6 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.weight_norm as wn
################
# Upsampler
################
def make_coord(shape, ranges=None, flatten=True):
"""Make coordinates at grid centers."""
coord_seqs = []
for i, n in enumerate(shape):
if ranges is None:
v0, v1 = -1, 1
else:
v0, v1 = ranges[i]
r = (v1 - v0) / (2 * n)
seq = v0 + r + (2 * r) * torch.arange(n).float()
coord_seqs.append(seq)
ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
if flatten:
ret = ret.view(-1, ret.shape[-1])
return ret
class UPLayer_MS_V9(nn.Module):
# Up-sampling net
def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
super().__init__()
self.interpolate_mode = interpolate_mode
self.levels = levels
self.UPNet_x2_list = []
for _ in range(levels - 1):
self.UPNet_x2_list.append(
nn.Sequential(
*[
nn.Conv2d(
n_feats,
n_feats * 4,
kSize,
padding=(kSize - 1) // 2,
stride=1,
),
nn.PixelShuffle(2),
]
)
)
self.scale_aware_layer = nn.Sequential(
*[nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, levels), nn.Sigmoid()]
)
self.UPNet_x2_list = nn.Sequential(*self.UPNet_x2_list)
self.fuse = nn.Sequential(
*[
nn.Conv2d(n_feats * levels, 256, kernel_size=1, padding=0, stride=1),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
nn.ReLU(),
nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1),
nn.ReLU(),
nn.Conv2d(256, out_channels, kernel_size=1, padding=0, stride=1),
]
)
def forward(self, x, out_size):
if type(out_size) == int:
out_size = [out_size, out_size]
if type(x) == list:
return self.forward_list(x, out_size)
r = torch.tensor([x.shape[2] / out_size[0]], device="cuda")
scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]
# scale_in = x.new_tensor(np.ones([x.shape[0], 1, out_size[0], out_size[1]])*r)
x_list = [x]
for l in range(1, self.levels):
x_list.append(self.UPNet_x2_list[l - 1](x_list[l - 1]))
x_resize_list = []
for l in range(self.levels):
x_resize = F.interpolate(
x_list[l], out_size, mode=self.interpolate_mode, align_corners=False
)
x_resize *= scale_w[l]
x_resize_list.append(x_resize)
# x_resize_list.append(scale_in)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
def forward_list(self, h_list, out_size):
assert (
len(h_list) == self.levels
), "The Length of input list must equal to the number of levels"
device = h_list[0].device
r = torch.tensor([h_list[0].shape[2] / out_size[0]], device=device)
scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]
x_resize_list = []
for l in range(self.levels):
h = h_list[l]
for i in range(l):
h = self.UPNet_x2_list[i](h)
x_resize = F.interpolate(
h, out_size, mode=self.interpolate_mode, align_corners=False
)
x_resize *= scale_w[l]
x_resize_list.append(x_resize)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
class UPLayer_MS_WN(nn.Module):
# Up-sampling net
def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
super().__init__()
self.interpolate_mode = interpolate_mode
self.levels = levels
self.UPNet_x2_list = []
for _ in range(levels - 1):
self.UPNet_x2_list.append(
nn.Sequential(
*[
wn(
nn.Conv2d(
n_feats,
n_feats * 4,
kSize,
padding=(kSize - 1) // 2,
stride=1,
)
),
nn.PixelShuffle(2),
]
)
)
self.scale_aware_layer = nn.Sequential(
*[wn(nn.Linear(1, 64)), nn.ReLU(), wn(nn.Linear(64, levels)), nn.Sigmoid()]
)
self.UPNet_x2_list = nn.Sequential(*self.UPNet_x2_list)
self.fuse = nn.Sequential(
*[
wn(
nn.Conv2d(n_feats * levels, 256, kernel_size=1, padding=0, stride=1)
),
nn.ReLU(),
wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
nn.ReLU(),
wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
nn.ReLU(),
wn(nn.Conv2d(256, 256, kernel_size=1, padding=0, stride=1)),
nn.ReLU(),
wn(nn.Conv2d(256, out_channels, kernel_size=1, padding=0, stride=1)),
]
)
assert self.interpolate_mode in (
"bilinear",
"bicubic",
"nearest",
"MLP",
), "Interpolate mode must be bilinear/bicubic/nearest/MLP"
if self.interpolate_mode == "MLP":
self.feature_interpolater = MLP_Interpolate(n_feats, radius=3)
elif self.interpolate_mode == "nearest":
self.feature_interpolater = lambda x, out_size: F.interpolate(
x, out_size, mode=self.interpolate_mode
)
else:
self.feature_interpolater = lambda x, out_size: F.interpolate(
x, out_size, mode=self.interpolate_mode, align_corners=False
)
def forward(self, x, out_size):
if type(out_size) == int:
out_size = [out_size, out_size]
if type(x) == list:
return self.forward_list(x, out_size)
r = torch.tensor([x.shape[2] / out_size[0]], device="cuda")
scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]
x_list = [x]
for l in range(1, self.levels):
x_list.append(self.UPNet_x2_list[l - 1](x_list[l - 1]))
x_resize_list = []
for l in range(self.levels):
x_resize = self.feature_interpolater(x_list[l], out_size)
x_resize *= scale_w[l]
x_resize_list.append(x_resize)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
def forward_list(self, h_list, out_size):
assert (
len(h_list) == self.levels
), "The Length of input list must equal to the number of levels"
device = h_list[0].device
r = torch.tensor([h_list[0].shape[2] / out_size[0]], device=device)
scale_w = self.scale_aware_layer(r.unsqueeze(0))[0]
x_resize_list = []
for l in range(self.levels):
h = h_list[l]
for i in range(l):
h = self.UPNet_x2_list[i](h)
x_resize = self.feature_interpolater(h, out_size)
x_resize *= scale_w[l]
x_resize_list.append(x_resize)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
class UPLayer_MS_WN_woSA(UPLayer_MS_WN):
def __init__(self, n_feats, kSize, out_channels, interpolate_mode, levels=4):
super().__init__(n_feats, kSize, out_channels, interpolate_mode, levels)
def forward(self, x, out_size):
if type(out_size) == int:
out_size = [out_size, out_size]
if type(x) == list:
return self.forward_list(x, out_size)
x_list = [x]
for l in range(1, self.levels):
x_list.append(self.UPNet_x2_list[l - 1](x_list[l - 1]))
x_resize_list = []
for l in range(self.levels):
x_resize = self.feature_interpolater(x_list[l], out_size)
x_resize_list.append(x_resize)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
def forward_list(self, h_list, out_size):
assert (
len(h_list) == self.levels
), "The Length of input list must equal to the number of levels"
x_resize_list = []
for l in range(self.levels):
h = h_list[l]
for i in range(l):
h = self.UPNet_x2_list[i](h)
x_resize = self.feature_interpolater(h, out_size)
x_resize_list.append(x_resize)
out = self.fuse(torch.cat(tuple(x_resize_list), 1))
return out
class OSM(nn.Module):
def __init__(self, n_feats, overscale):
super().__init__()
self.body = nn.Sequential(
wn(nn.Conv2d(n_feats, 1600, 3, padding=1)),
nn.PixelShuffle(overscale),
wn(nn.Conv2d(64, 3, 3, padding=1)),
)
def forward(self, x, out_size):
h = self.body(x)
return F.interpolate(h, out_size, mode="bicubic", align_corners=False)
class MLP_Interpolate(nn.Module):
def __init__(self, n_feat, radius=2):
super().__init__()
self.radius = radius
self.f_transfer = nn.Sequential(
*[
nn.Linear(n_feat * self.radius * self.radius + 2, n_feat),
nn.ReLU(True),
nn.Linear(n_feat, n_feat),
]
)
def forward(self, x, out_size):
x_unfold = F.unfold(x, self.radius, padding=self.radius // 2)
x_unfold = x_unfold.view(
x.shape[0], x.shape[1] * (self.radius ** 2), x.shape[2], x.shape[3]
)
in_shape = x.shape[-2:]
in_coord = (
make_coord(in_shape, flatten=False)
.cuda()
.permute(2, 0, 1)
.unsqueeze(0)
.expand(x.shape[0], 2, *in_shape)
)
if type(out_size) == int:
out_size = [out_size, out_size]
out_coord = make_coord(out_size, flatten=True).cuda()
out_coord = out_coord.expand(x.shape[0], *out_coord.shape)
q_feat = F.grid_sample(
x_unfold,
out_coord.flip(-1).unsqueeze(1),
mode="nearest",
align_corners=False,
)[:, :, 0, :].permute(0, 2, 1)
q_coord = F.grid_sample(
in_coord,
out_coord.flip(-1).unsqueeze(1),
mode="nearest",
align_corners=False,
)[:, :, 0, :].permute(0, 2, 1)
rel_coord = out_coord - q_coord
rel_coord[:, :, 0] *= x.shape[-2]
rel_coord[:, :, 1] *= x.shape[-1]
inp = torch.cat([q_feat, rel_coord], dim=-1)
bs, q = out_coord.shape[:2]
pred = self.f_transfer(inp.view(bs * q, -1)).view(bs, q, -1)
pred = (
pred.view(x.shape[0], *out_size, x.shape[1])
.permute(0, 3, 1, 2)
.contiguous()
)
return pred
class LIIF_Upsampler(nn.Module):
def __init__(self):
super().__init__()
raise NotImplementedError
def forward(self):
pass