# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.utils import digit_version
from torch import Tensor

from mmdet.registry import MODELS

# Expose the plain ``nn.Linear`` through the registry under the key 'Linear'.
MODELS.register_module('Linear', module=nn.Linear)


@MODELS.register_module(name='NormedLinear')
class NormedLinear(nn.Linear):
    """Linear layer that normalizes both its weight and its input.

    Each weight row and each input sample is divided by its norm raised to
    ``power`` (plus ``eps``); the normalized input is then scaled by the
    temperature before the linear projection. The bias, if any, is applied
    unmodified.

    Positional and remaining keyword arguments are forwarded to
    ``nn.Linear``.

    Args:
        tempearture (float, optional): Temperature term the normalized
            input is multiplied by. Note that the keyword is spelled
            ``tempearture``; the spelling is kept because it is part of
            the constructor interface. Defaults to 20.
        power (float, optional): Exponent applied to the norms before
            dividing. Defaults to 1.0.
        eps (float, optional): Value added to the divisor to keep
            numerical stability. Defaults to 1e-6.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.eps = eps
        self.init_weights()

    def init_weights(self) -> None:
        """Initialize the weight from N(0, 0.01) and the bias to zero."""
        nn.init.normal_(self.weight, mean=0, std=0.01)
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedLinear`.

        Args:
            x (Tensor): Input features. The norm is taken over dim 1, so a
                2-D ``(N, in_features)`` input is assumed here.

        Returns:
            Tensor: Result of the linear projection of the normalized,
            temperature-scaled input with the normalized weight.
        """
        # Normalize every output row of the weight over its input features.
        weight_norm = self.weight.norm(dim=1, keepdim=True).pow(self.power)
        weight_ = self.weight / (weight_norm + self.eps)

        # Normalize the input over dim 1, then apply the temperature.
        x_norm = x.norm(dim=1, keepdim=True).pow(self.power)
        x_ = x / (x_norm + self.eps)
        x_ = x_ * self.tempearture

        return F.linear(x_, weight_, self.bias)


@MODELS.register_module(name='NormedConv2d')
class NormedConv2d(nn.Conv2d):
    """Conv2d layer that normalizes both its weight and its input.

    The input is divided by its norm over dim 1 (raised to ``power``, plus
    ``eps``) and scaled by the temperature. The weight is normalized either
    over dim 1 only, or over everything but dim 0 when
    ``norm_over_kernel`` is set.

    Positional and remaining keyword arguments are forwarded to
    ``nn.Conv2d``.

    Args:
        tempearture (float, optional): Temperature term the normalized
            input is multiplied by. Note that the keyword is spelled
            ``tempearture``; the spelling is kept because it is part of
            the constructor interface. Defaults to 20.
        power (float, optional): Exponent applied to the norms before
            dividing. Defaults to 1.0.
        eps (float, optional): Value added to the divisor to keep
            numerical stability. Defaults to 1e-6.
        norm_over_kernel (bool, optional): If True, compute one norm per
            output channel over all remaining weight dimensions instead
            of over dim 1 only. Defaults to False.
    """

    def __init__(self,
                 *args,
                 tempearture: float = 20,
                 power: float = 1.0,
                 eps: float = 1e-6,
                 norm_over_kernel: bool = False,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.tempearture = tempearture
        self.power = power
        self.norm_over_kernel = norm_over_kernel
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for `NormedConv2d`.

        Args:
            x (Tensor): Input feature map; normalized over dim 1.

        Returns:
            Tensor: Convolution of the normalized, temperature-scaled
            input with the normalized weight.
        """
        if not self.norm_over_kernel:
            weight_norm = self.weight.norm(
                dim=1, keepdim=True).pow(self.power)
        else:
            # ``flatten`` instead of ``view``: a weight that is not
            # contiguous (e.g. after converting the module to
            # ``torch.channels_last``) cannot be viewed as (out, -1) and
            # would raise. The norm does not depend on element order, so
            # the result is unchanged for contiguous weights.
            weight_norm = self.weight.flatten(1).norm(
                dim=1, keepdim=True).pow(self.power)[..., None, None]
        weight_ = self.weight / (weight_norm + self.eps)

        x_norm = x.norm(dim=1, keepdim=True).pow(self.power)
        x_ = x / (x_norm + self.eps)
        x_ = x_ * self.tempearture

        # The private conv entry point differs between PyTorch releases:
        # use ``conv2d_forward`` when the base class provides it
        # (presumably legacy releases), otherwise ``_conv_forward``, which
        # takes the bias explicitly from version 1.8 on.
        if hasattr(self, 'conv2d_forward'):
            x_ = self.conv2d_forward(x_, weight_)
        else:
            if digit_version(torch.__version__) >= digit_version('1.8'):
                x_ = self._conv_forward(x_, weight_, self.bias)
            else:
                x_ = self._conv_forward(x_, weight_)
        return x_