""" Misc functions, including distributed helpers. Mostly copy-paste from torchvision references. """ from typing import List, Optional import torch import torch.distributed as dist import torchvision from torch import Tensor if float(torchvision.__version__.split(".")[1]) < 7.0: from torchvision.ops import _new_empty_tensor from torchvision.ops.misc import _output_size def _max_by_axis(the_list): # type: (List[List[int]]) -> List[int] maxes = the_list[0] for sublist in the_list[1:]: for index, item in enumerate(sublist): maxes[index] = max(maxes[index], item) return maxes def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor """ Equivalent to nn.functional.interpolate, but with support for empty batch sizes. This will eventually be supported natively by PyTorch, and this class can go away. """ if float(torchvision.__version__.split(".")[1]) < 7.0: if input.numel() > 0: return torch.nn.functional.interpolate( input, size, scale_factor, mode, align_corners ) output_shape = _output_size(2, input, size, scale_factor) output_shape = list(input.shape[:-2]) + list(output_shape) return _new_empty_tensor(input, output_shape) else: return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) class NestedTensor(object): def __init__(self, tensors, mask: Optional[Tensor]): self.tensors = tensors self.mask = mask def to(self, device): # type: (Device) -> NestedTensor # noqa cast_tensor = self.tensors.to(device) mask = self.mask if mask is not None: assert mask is not None cast_mask = mask.to(device) else: cast_mask = None return NestedTensor(cast_tensor, cast_mask) def decompose(self): return self.tensors, self.mask def __repr__(self): return str(self.tensors) def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): # TODO make this more general if tensor_list[0].ndim == 3: if torchvision._is_tracing(): # nested_tensor_from_tensor_list() does not export well to ONNX # call _onnx_nested_tensor_from_tensor_list() instead return _onnx_nested_tensor_from_tensor_list(tensor_list) # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) batch_shape = [len(tensor_list)] + max_size b, c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) for img, pad_img, m in zip(tensor_list, tensor, mask): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False else: raise ValueError("not supported") return NestedTensor(tensor, mask) # _onnx_nested_tensor_from_tensor_list() is an implementation of # nested_tensor_from_tensor_list() that is supported by ONNX tracing. @torch.jit.unused def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: max_size = [] for i in range(tensor_list[0].dim()): max_size_i = torch.max( torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) ).to(torch.int64) max_size.append(max_size_i) max_size = tuple(max_size) # work around for # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) # m[: img.shape[1], :img.shape[2]] = False # which is not yet supported in onnx padded_imgs = [] padded_masks = [] for img in tensor_list: padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) padded_imgs.append(padded_img) m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) padded_masks.append(padded_mask.to(torch.bool)) tensor = torch.stack(padded_imgs) mask = torch.stack(padded_masks) return NestedTensor(tensor, mask=mask) def is_dist_avail_and_initialized(): if not dist.is_available(): return False if not dist.is_initialized(): return False return True