"""Custom replacement for `torch.nn.functional.conv2d` that supports |
arbitrarily high order gradients with zero performance penalty.""" |
import warnings |
import contextlib |
import torch |
from distutils.version import LooseVersion |
# Global opt-in switch: the custom op is only used when this is set to True.
enabled = False
# Set (temporarily) by `no_weight_gradients()`; while True, the custom op's
# backward pass leaves the weight gradient as None.
weight_gradients_disabled = False
# True for PyTorch releases older than 1.11, which take the legacy
# cuDNN JIT-operator path for weight gradients instead of
# `aten::convolution_backward`.
old_version = LooseVersion('1.11.0') > LooseVersion(torch.__version__)
@contextlib.contextmanager
def no_weight_gradients(disable=True):
    """Context manager that turns off weight gradients of the custom op.

    While active, the custom convolution's backward pass checks the
    module-level `weight_gradients_disabled` flag and leaves the weight
    gradient as None. The previous flag value is restored on exit, including
    when the body raises.

    Args:
        disable: When False, the flag is left unchanged, so callers can toggle
            the behavior without restructuring their `with` statement.
            Defaults to True (the original behavior).
    """
    global weight_gradients_disabled
    old = weight_gradients_disabled
    if disable:
        weight_gradients_disabled = True
    try:
        yield
    finally:
        # Restore even on exceptions; otherwise a single failing step would
        # leave weight gradients disabled for the rest of the process.
        weight_gradients_disabled = old
def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """Drop-in replacement for `torch.nn.functional.conv2d`.

    When `_should_use_custom_op(input)` approves, the convolution runs through
    the cached autograd Function built by `_conv2d_gradfix`; otherwise the
    call is forwarded unchanged to the stock PyTorch implementation.
    """
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)
    conv_fn = _conv2d_gradfix(
        transpose=False,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=0,
        dilation=dilation,
        groups=groups,
    )
    return conv_fn.apply(input, weight, bias)
def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1):
    """Drop-in replacement for `torch.nn.functional.conv_transpose2d`.

    When `_should_use_custom_op(input)` approves, the transposed convolution
    runs through the cached autograd Function built by `_conv2d_gradfix`
    (with `transpose=True`); otherwise the call is forwarded unchanged to the
    stock PyTorch implementation.
    """
    if not _should_use_custom_op(input):
        return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation)
    deconv_fn = _conv2d_gradfix(
        transpose=True,
        weight_shape=weight.shape,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )
    return deconv_fn.apply(input, weight, bias)
def _should_use_custom_op(input):
    """Decide whether `input` should go through the custom gradfix op.

    Returns False unless the module-level `enabled` switch is on, cuDNN is
    enabled, and `input` lives on a CUDA device. Beyond that, the op is used
    only on PyTorch >= 1.7.0; older versions get a warning and False.
    """
    assert isinstance(input, torch.Tensor)
    eligible = enabled and torch.backends.cudnn.enabled and input.device.type == 'cuda'
    if not eligible:
        return False
    supported = LooseVersion(torch.__version__) >= LooseVersion('1.7.0')
    if not supported:
        warnings.warn(f'conv2d_gradfix not supported on PyTorch {torch.__version__}. Falling back to torch.nn.functional.conv2d().')
    return supported
def _tuple_of_ints(xs, ndim):
    """Normalize a scalar-or-sequence argument to a tuple of `ndim` ints.

    A tuple or list is converted element-wise; any other value is repeated
    `ndim` times. Asserts that the result has exactly `ndim` entries and that
    every entry is an `int`.
    """
    if isinstance(xs, (tuple, list)):
        values = tuple(xs)
    else:
        values = tuple([xs] * ndim)
    assert len(values) == ndim
    assert all(isinstance(v, int) for v in values)
    return values
# Memoizes the autograd Function classes built by `_conv2d_gradfix`, keyed by
# (transpose, weight_shape, stride, padding, output_padding, dilation, groups).
_conv2d_gradfix_cache = {}
def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups):
    """Build (or fetch from cache) an autograd Function for one conv config.

    The returned class runs `conv2d` (or `conv_transpose2d` when `transpose`
    is true) in its forward pass. Its backward pass is expressed with other
    Functions from this same factory plus a dedicated weight-gradient
    Function, and each of those is differentiable again. This is what allows
    gradients of arbitrarily high order.

    Args:
        transpose: If true, the forward pass is a transposed convolution.
        weight_shape: Expected weight shape; `forward` asserts the actual
            weight matches it.
        stride, padding, output_padding, dilation: Int or sequence of 2 ints.
        groups: Number of convolution groups (>= 1).

    Returns:
        A `torch.autograd.Function` subclass; call `.apply(input, weight, bias)`.
        Classes are memoized per configuration in `_conv2d_gradfix_cache`.
    """
    ndim = 2
    weight_shape = tuple(weight_shape)
    stride = _tuple_of_ints(stride, ndim)
    padding = _tuple_of_ints(padding, ndim)
    output_padding = _tuple_of_ints(output_padding, ndim)
    dilation = _tuple_of_ints(dilation, ndim)

    # Reuse the Function previously built for an identical configuration.
    key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups)
    if key in _conv2d_gradfix_cache:
        return _conv2d_gradfix_cache[key]

    # Validate parameters (only reached for configurations not yet cached).
    assert groups >= 1
    assert len(weight_shape) == ndim + 2
    assert all(stride[i] >= 1 for i in range(ndim))
    assert all(padding[i] >= 0 for i in range(ndim))
    # Dilation must be positive; PyTorch convolutions reject a dilation of 0.
    assert all(dilation[i] >= 1 for i in range(ndim))
    if not transpose:
        assert all(output_padding[i] == 0 for i in range(ndim))
    else:
        assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim))

    common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups)

    def calc_output_padding(input_shape, output_shape):
        """Output padding for the opposite-direction conv used in backward.

        For a regular conv, the input gradient is a transposed conv of
        `grad_output`. This returns the per-dimension `output_padding` that
        makes that transposed conv produce exactly `input_shape`. When this op
        is itself transposed, the opposite op is a regular conv, which takes
        no output padding.
        """
        if transpose:
            return [0, 0]
        return [
            input_shape[i + 2]
            - (output_shape[i + 2] - 1) * stride[i]
            - (1 - 2 * padding[i])
            - dilation[i] * (weight_shape[i + 2] - 1)
            for i in range(ndim)
        ]

    class Conv2d(torch.autograd.Function):
        """Forward conv (or transposed conv) with a re-differentiable backward."""

        @staticmethod
        def forward(ctx, input, weight, bias):
            assert weight.shape == weight_shape
            if not transpose:
                output = torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs)
            else:
                output = torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs)
            ctx.save_for_backward(input, weight, bias)
            return output

        @staticmethod
        def backward(ctx, grad_output):
            input, weight, bias = ctx.saved_tensors
            grad_input = None
            grad_weight = None
            grad_bias = None

            if ctx.needs_input_grad[0]:
                # The input gradient is the opposite-direction conv of grad_output
                # with the same weight, built through this factory so that it can
                # be differentiated again.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, weight, None)
                assert grad_input.shape == input.shape

            # Skipped while inside `no_weight_gradients()`.
            if ctx.needs_input_grad[1] and not weight_gradients_disabled:
                grad_weight = Conv2dGradWeight.apply(grad_output, input, bias)
                assert grad_weight.shape == weight_shape

            if ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum([0, 2, 3])

            return grad_input, grad_weight, grad_bias

    class Conv2dGradWeight(torch.autograd.Function):
        """Weight gradient of `Conv2d`, itself differentiable.

        The `bias` input is used only to pass its shape to
        `convolution_backward`; it receives no gradient.
        """

        @staticmethod
        def forward(ctx, grad_output, input, bias):
            if old_version:
                # PyTorch < 1.11: call the cuDNN weight-gradient operators
                # directly via the JIT operator registry.
                op = torch._C._jit_get_operation(
                    'aten::cudnn_convolution_backward_weight' if not transpose else 'aten::cudnn_convolution_transpose_backward_weight')
                flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                         torch.backends.cudnn.allow_tf32]
                grad_weight = op(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags)
            else:
                # PyTorch >= 1.11: use the generic `convolution_backward` and
                # request only the weight gradient (output_mask is bool[3]).
                # The uninitialized `empty_weight` placeholder presumably only
                # supplies shape/dtype/device here -- NOTE(review): confirm.
                bias_shape = bias.shape if (bias is not None) else None
                empty_weight = torch.empty(weight_shape, dtype=input.dtype, layout=input.layout, device=input.device)
                grad_weight = torch.ops.aten.convolution_backward(grad_output, input, empty_weight, bias_sizes=bias_shape, stride=stride, padding=padding, dilation=dilation, transposed=transpose, output_padding=output_padding, groups=groups, output_mask=[False, True, False])[1]
            assert grad_weight.shape == weight_shape
            ctx.save_for_backward(grad_output, input)
            return grad_weight

        @staticmethod
        def backward(ctx, grad2_grad_weight):
            grad_output, input = ctx.saved_tensors
            grad2_grad_output = None
            grad2_input = None

            if ctx.needs_input_grad[0]:
                # d(grad_weight)/d(grad_output): the forward conv applied with the
                # incoming weight-gradient cotangent as its weight.
                grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None)
                assert grad2_grad_output.shape == grad_output.shape

            if ctx.needs_input_grad[1]:
                # d(grad_weight)/d(input): the opposite-direction conv of grad_output.
                p = calc_output_padding(input_shape=input.shape, output_shape=grad_output.shape)
                grad2_input = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs).apply(grad_output, grad2_grad_weight, None)
                assert grad2_input.shape == input.shape

            # autograd requires exactly one gradient per forward input
            # (grad_output, input, bias); `bias` only carries a shape, so None.
            return grad2_grad_output, grad2_input, None

    _conv2d_gradfix_cache[key] = Conv2d
    return Conv2d