from torch import nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable

from tensormask import _C


class _SwapAlign2Nat(Function):
    @staticmethod
    def forward(ctx, X, lambda_val, pad_val):
        # Save what the backward pass needs: the unit-length ratio and the
        # shape of the input tensor.
        ctx.lambda_val = lambda_val
        ctx.input_shape = X.size()

        Y = _C.swap_align2nat_forward(X, lambda_val, pad_val)
        return Y

    @staticmethod
    @once_differentiable
    def backward(ctx, gY):
        lambda_val = ctx.lambda_val
        bs, ch, h, w = ctx.input_shape

        gX = _C.swap_align2nat_backward(gY, lambda_val, bs, ch, h, w)

        # lambda_val and pad_val are non-tensor arguments, so no gradients
        # are returned for them.
        return gX, None, None


swap_align2nat = _SwapAlign2Nat.apply
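# A quick first-order gradient sanity check for the custom Function above can
# be run with torch.autograd.gradcheck (a sketch only: it assumes the
# tensormask CUDA extension is built and that the kernel supports float64;
# the shapes and lambda_val are illustrative):
#
#   import torch
#   from torch.autograd import gradcheck
#
#   X = torch.randn(2, 4 * 4, 8, 8, dtype=torch.float64, device="cuda",
#                   requires_grad=True)
#   gradcheck(swap_align2nat, (X, 2, -6.0))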


class SwapAlign2Nat(nn.Module):
    """
    The op `SwapAlign2Nat` described in https://arxiv.org/abs/1903.12174.
    Given an input tensor that predicts masks of shape (N, C=VxU, H, W),
    applying the op returns masks of shape (N, V'xU', H', W'), where the unit
    lengths of (V, U) and (H, W) are swapped, and the mask representation is
    transformed from aligned to natural.

    Args:
        lambda_val (int): the relative unit-length ratio between (V, U) and
                          (H, W). Since the unit lengths of (V, U) are always
                          larger than those of (H, W), lambda_val is always >= 1.
        pad_val (float):  padding value for locations falling outside of the
                          input tensor. It defaults to -6, since sigmoid(-6)
                          is ~0, indicating that there are no masks outside of
                          the tensor.
    """

    def __init__(self, lambda_val, pad_val=-6.0):
        super(SwapAlign2Nat, self).__init__()
        self.lambda_val = lambda_val
        self.pad_val = pad_val

    def forward(self, X):
        return swap_align2nat(X, self.lambda_val, self.pad_val)

    def __repr__(self):
        tmpstr = self.__class__.__name__ + "("
        tmpstr += "lambda_val=" + str(self.lambda_val)
        tmpstr += ", pad_val=" + str(self.pad_val)
        tmpstr += ")"
        return tmpstr
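

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library): it assumes the
    # `tensormask` CUDA extension has been built and a GPU is available.
    # The batch size, V=U=4 channel layout, spatial size, and lambda_val
    # below are illustrative only.
    import torch

    op = SwapAlign2Nat(lambda_val=2)
    X = torch.randn(1, 4 * 4, 14, 14, device="cuda")
    Y = op(X)
    # The op swaps the unit lengths of (V, U) and (H, W); the resulting
    # output shape is easiest to inspect by printing it.
    print(op)
    print(tuple(X.shape), "->", tuple(Y.shape))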