Wrappers around some nn functions, mainly to support empty tensors.
Ideally, PyTorch would support empty tensors directly in those functions.
These can be removed once https://github.com/pytorch/pytorch/issues/12013
is implemented
import functools
import warnings
from typing import List, Optional
import torch
from torch.nn import functional as F
from detectron2.utils.env import TORCH_VERSION
def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Build a 1-D tensor out of a list of integer scalars (or integer scalar Tensors),
    in a way that works in eager mode, under tracing and under scripting.

    While tracing, every element of `x` must be a scalar Tensor so that the result
    stays connected to the traced inputs. In scripting or eager mode, `x` is a
    list of int.
    """
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if not torch.jit.is_tracing():
        # Plain eager execution.
        return torch.as_tensor(x, device=device)

    # Tracing: stack the scalar tensors instead of going through as_tensor.
    elements_are_tensors = [isinstance(item, torch.Tensor) for item in x]
    assert all(elements_are_tensors), "Shape should be tensor during tracing!"
    stacked = torch.stack(x)
    if stacked.device == device:
        return stacked
    # Only move the result when it is not already on the requested device.
    return stacked.to(device=device)
def check_if_dynamo_compiling():
    """
    Return the result of `torch._dynamo.is_compiling()` when `TORCH_VERSION`
    is at least (2, 1); return False for older versions.
    """
    new_enough = TORCH_VERSION >= (2, 1)
    if not new_enough:
        return False

    # Imported lazily so that older PyTorch versions never touch torch._dynamo.
    from torch._dynamo import is_compiling

    return is_compiling()
def disable_torch_compiler(func):
    """
    Decorator that keeps `func` out of torch.compile.

    Args:
        func (callable): the function to protect from compilation.

    Returns:
        callable: when `TORCH_VERSION` is at least (2, 1), a wrapper around `func`
        that is marked with `torch.compiler.disable` and carries `func`'s
        metadata (name, docstring, ...). Otherwise `func` itself, unchanged.
    """
    if TORCH_VERSION >= (2, 1):
        # functools.wraps keeps the decorated function's name/docstring intact;
        # torch.compiler.disable is what actually excludes it from compilation.
        @functools.wraps(func)
        @torch.compiler.disable
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        return wrapper

    # Older versions: nothing to disable, hand the function back untouched.
    return func
def cat(tensors: List[torch.Tensor], dim: int = 0):
    """
    Concatenate `tensors` along `dim` like torch.cat, except that a sequence holding
    exactly one tensor returns that tensor itself (no copy is made).
    """
    assert isinstance(tensors, (list, tuple))
    if len(tensors) != 1:
        return torch.cat(tensors, dim)
    # Single element: hand it back as-is instead of copying through torch.cat.
    return tensors[0]
def empty_input_loss_func_wrapper(loss_func):
    """
    Wrap `loss_func` so that an empty target with "mean" reduction yields a zero
    loss instead of calling `loss_func`.
    """

    def wrapped_loss_func(input, target, *, reduction="mean", **kwargs):
        """
        Same as `loss_func`, but returns 0 (instead of nan) for empty inputs.
        """
        has_targets = target.numel() != 0
        if has_targets or reduction != "mean":
            return loss_func(input, target, reduction=reduction, **kwargs)
        # Zero that is still computed from `input`, so it stays in the autograd graph.
        return input.sum() * 0.0

    return wrapped_loss_func
# `F.cross_entropy` passed through `empty_input_loss_func_wrapper`: with an empty
# target and reduction="mean" the wrapper returns `input.sum() * 0.0` rather than
# calling `F.cross_entropy`.
cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy)
class _NewEmptyTensorOp(torch.autograd.Function):
def forward(ctx, x, new_shape):
ctx.shape = x.shape
return x.new_empty(new_shape)
def backward(ctx, grad):
shape = ctx.shape
return _NewEmptyTensorOp.apply(grad, shape), None
class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts everything `torch.nn.Conv2d` accepts, plus two extra keyword arguments:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        The norm layer is applied before the activation.
        """
        # Strip the extra options before torch.nn.Conv2d sees the keyword arguments.
        activation = kwargs.pop("activation", None)
        norm = kwargs.pop("norm", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """
        Apply the convolution, then `self.norm` and `self.activation` when they are set.
        """
        # The sanity check below is skipped under scripting and whenever
        # check_if_dynamo_compiling() reports True.
        if not torch.jit.is_scripting():
            if not check_if_dynamo_compiling():
                with warnings.catch_warnings(record=True):
                    empty_in_training = x.numel() == 0 and self.training
                    if empty_in_training:
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        out = F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )
        if self.norm is not None:
            out = self.norm(out)
        if self.activation is not None:
            out = self.activation(out)
        return out
# Plain aliases of the stock PyTorch classes/functions; nothing is wrapped here.
# NOTE(review): presumably kept so callers can keep importing these names from
# this module — confirm before removing.
ConvTranspose2d = torch.nn.ConvTranspose2d
BatchNorm2d = torch.nn.BatchNorm2d
interpolate = F.interpolate
Linear = torch.nn.Linear
def nonzero_tuple(x):
    """
    An 'as_tuple=True' flavour of torch.nonzero that also works under torchscript,
    needed because of https://github.com/pytorch/pytorch/issues/38718
    """
    if torch.jit.is_scripting():
        # A 0-dim tensor is promoted to 1-dim first so that unbind(1) has a
        # column dimension to split on.
        source = x.unsqueeze(0) if x.dim() == 0 else x
        return source.nonzero().unbind(1)
    else:
        return x.nonzero(as_tuple=True)
@torch.jit.script_if_tracing
def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor:
    """
    Tracing friendly way to cast tensor to another tensor's device.

    During tracing a device is treated as a constant; scripting the casting
    process as a whole works around this, which is why this function is decorated
    with `torch.jit.script_if_tracing`.

    Args:
        src (Tensor): the tensor to move.
        dst (Tensor): the tensor whose device is the destination.

    Returns:
        Tensor: the result of `src.to(dst.device)`.
    """
    return src.to(dst.device)