# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks.wrappers import NewEmptyTensorOp, obsolete_torch_version

from mmdet.registry import MODELS

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
    # torch.__version__ could be 1.3.1+cu92, we only need the first two
    # numbers for comparison
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])


def adaptive_avg_pool2d(input, output_size):
    """Handle empty batch dimension for adaptive_avg_pool2d.

    Args:
        input (tensor): 4D tensor.
        output_size (int, tuple[int, int]): the target output size.
    """
    if input.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
        # PyTorch <= 1.9 cannot pool an empty tensor, so build an empty
        # output of the pooled shape instead of calling the op.
        if isinstance(output_size, int):
            output_size = [output_size, output_size]
        output_size = [*input.shape[:2], *output_size]
        empty = NewEmptyTensorOp.apply(input, output_size)
        return empty
    else:
        return F.adaptive_avg_pool2d(input, output_size)
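

# Minimal usage sketch (illustrative, not part of the original module): with
# an empty batch the wrapper returns an empty tensor of the pooled shape
# instead of raising on old PyTorch versions.
#
#   feats = torch.rand(0, 256, 14, 14)  # e.g. RoI features with no boxes
#   out = adaptive_avg_pool2d(feats, 7)
#   assert tuple(out.shape) == (0, 256, 7, 7)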


class AdaptiveAvgPool2d(nn.AdaptiveAvgPool2d):
    """Handle empty batch dimension for AdaptiveAvgPool2d."""

    def forward(self, x):
        # PyTorch <= 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            output_size = self.output_size
            if isinstance(output_size, int):
                output_size = [output_size, output_size]
            else:
                # a `None` entry means "keep the input's size in that dim"
                output_size = [
                    v if v is not None else d
                    for v, d in zip(output_size,
                                    x.size()[-2:])
                ]
            output_size = [*x.shape[:2], *output_size]
            empty = NewEmptyTensorOp.apply(x, output_size)
            return empty

        return super().forward(x)
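

# Illustrative sketch (assumed behavior, mirroring nn.AdaptiveAvgPool2d
# semantics): `None` entries in `output_size` keep the input's spatial size,
# including on the empty-batch path above.
#
#   pool = AdaptiveAvgPool2d((None, 1))
#   out = pool(torch.rand(0, 256, 14, 14))
#   assert tuple(out.shape) == (0, 256, 14, 1)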


# Modified from
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py#L13  # noqa
class FrozenBatchNorm2d(nn.Module):
    """BatchNorm2d where the batch statistics and the affine parameters are
    fixed.

    It contains non-trainable buffers called "weight", "bias",
    "running_mean" and "running_var", initialized to perform an identity
    transformation.

    Args:
        num_features (int): :math:`C` from an expected input of size
            :math:`(N, C, H, W)`.
        eps (float): a value added to the denominator for numerical
            stability. Default: 1e-5
    """

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer('weight', torch.ones(num_features))
        self.register_buffer('bias', torch.zeros(num_features))
        self.register_buffer('running_mean', torch.zeros(num_features))
        # running_var is initialized to 1 - eps so that the initial scale
        # weight / sqrt(running_var + eps) equals exactly 1 (identity).
        self.register_buffer('running_var', torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias
            # as well. Fold the frozen statistics into a per-channel affine
            # transform instead:
            #   y = (x - mean) / sqrt(var + eps) * weight + bias
            #     = x * scale + (bias - mean * scale)
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused
            # op and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def __repr__(self):
        return 'FrozenBatchNorm2d(num_features={}, eps={})'.format(
            self.num_features, self.eps)
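
    # Sanity-check sketch (illustrative): a freshly constructed module is an
    # identity mapping, since scale = 1 / sqrt((1 - eps) + eps) = 1 and the
    # folded bias is 0.
    #
    #   bn = FrozenBatchNorm2d(8)
    #   x = torch.rand(2, 8, 4, 4)
    #   assert torch.allclose(bn(x), x, atol=1e-6)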

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module): the module to convert.

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, converts the module in place and returns it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            # Replace the BN layer itself, copying over its frozen state.
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            # Recurse into children and swap any converted BN layers in.
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
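

# Usage sketch (hypothetical setup; `torchvision` and the backbone choice are
# assumptions, not part of this module): freeze every BatchNorm in a model so
# its statistics and affine parameters stay fixed during fine-tuning.
#
#   import torchvision
#   backbone = torchvision.models.resnet50()
#   backbone = FrozenBatchNorm2d.convert_frozen_batchnorm(backbone)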