|
|
|
import torch |
|
import torchvision |
|
from distutils.version import LooseVersion |
|
from torch import nn as nn |
|
from torch.nn import init as init |
|
from .dcn import ModulatedDeformConvPack, modulated_deform_conv |
|
|
|
|
|
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Different from the official DCNv2Pack, which generates offsets and masks
    from the preceding features, this DCNv2Pack takes a separate feature map
    to generate offsets and masks.

    ``Paper: Delving Deep into Deformable Alignment in Video Super-Resolution``
    """

    # A diagnostic message is printed when the mean absolute offset exceeds
    # this value, to flag suspiciously large predicted offsets.
    offset_warn_threshold = 250

    # torchvision>=0.9.0 provides deform_conv2d with a ``mask`` argument
    # (modulated DCN); older versions fall back to the local
    # ``modulated_deform_conv`` op. The installed version cannot change at
    # runtime, so the comparison is done once instead of on every forward.
    _use_torchvision_op = LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0')

    def forward(self, x, feat):
        """Apply the modulated deformable convolution to ``x``.

        Args:
            x (Tensor): Input features to be deformed and convolved.
            feat (Tensor): Features passed through ``self.conv_offset`` to
                predict the offsets and the modulation mask.

        Returns:
            Tensor: Output of the modulated deformable convolution.
        """
        out = self.conv_offset(feat)
        # Split the conv_offset output into 3 channel chunks: the first two
        # form the offsets, the last one is the (pre-sigmoid) mask.
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        # Diagnostic only: detach so no autograd graph is recorded for it.
        offset_absmean = offset.detach().abs().mean()
        if offset_absmean > self.offset_warn_threshold:
            # Report the threshold actually used (the old message hard-coded
            # a different number than the comparison).
            print(f'Offset abs mean is {offset_absmean.item()}, '
                  f'larger than {self.offset_warn_threshold}.')

        if self._use_torchvision_op:
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
|
|
|
|
|
|
|
class FlowGuidedDCN(ModulatedDeformConvPack):
    """Modulated deformable conv guided by a flow field.

    Offsets and masks are generated from other features (``feat``) rather
    than from the convolved input. The predicted offsets are bounded with
    ``tanh``, scaled by ``max_residue_magnitude`` and added to the given flow.
    """

    # Scale applied to tanh of the predicted offsets, i.e. the maximum
    # magnitude of the learned residual added on top of the flow.
    max_residue_magnitude = 15

    # A diagnostic message is printed when the mean absolute offset
    # (residual + flow) exceeds this value.
    offset_warn_threshold = 250

    # torchvision>=0.9.0 provides deform_conv2d with a ``mask`` argument
    # (modulated DCN); older versions fall back to the local
    # ``modulated_deform_conv`` op. Checked once at class creation.
    _use_torchvision_op = LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0')

    def forward(self, x, feat, flows):
        """Apply the flow-guided modulated deformable convolution.

        Args:
            x (Tensor): Input features for deformable conv: N, C, H, W.
            feat (Tensor): Other features used for generating offsets and
                mask: N, C, H, W.
            flows (Tensor): Flow field: N, 2, H, W. Its two channels are
                swapped (``flip(1)``) before being added to the offsets —
                presumably (x, y) -> (y, x); confirm against callers.

        Returns:
            Tensor: Output of the modulated deformable convolution.
        """
        out = self.conv_offset(feat)
        # Split the conv_offset output into 3 channel chunks: the first two
        # form the offset residual, the last one is the (pre-sigmoid) mask.
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        mask = torch.sigmoid(mask)

        # Bounded residual on top of the flow; the flow is repeated so that
        # every (pair of) offset channels receives it.
        offset = torch.tanh(torch.cat((o1, o2), dim=1)) * self.max_residue_magnitude
        offset = offset + flows.flip(1).repeat(1, offset.size(1) // 2, 1, 1)

        # Diagnostic only: detach so no autograd graph is recorded for it.
        offset_mean = offset.detach().abs().mean()
        if offset_mean > self.offset_warn_threshold:
            # Report the threshold actually used (the old message hard-coded
            # a different number than the comparison).
            print(f'FlowGuidedDCN: Offset mean is {offset_mean.item()}, '
                  f'larger than {self.offset_warn_threshold}.')

        if self._use_torchvision_op:
            return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
                                                 self.dilation, mask)
        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)