# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_norm_layer
from mmengine.model import BaseModule

from mmdet.registry import MODELS


@MODELS.register_module()
class FcModule(BaseModule):
    """Fully-connected layer module.

    Args:
        in_channels (int): Input channels.
        out_channels (int): Output channels.
        norm_cfg (dict, optional): Configuration of the normalization method
            after the fc layer. Defaults to None.
        act_cfg (dict, optional): Configuration of the activation method
            after the fc layer. Defaults to dict(type='ReLU').
        inplace (bool, optional): Whether to apply the activation module
            in-place. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to dict(type='Kaiming', layer='Linear').
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 norm_cfg: dict = None,
                 act_cfg: dict = dict(type='ReLU'),
                 inplace: bool = True,
                 init_cfg=dict(type='Kaiming', layer='Linear')):
        super(FcModule, self).__init__(init_cfg)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None

        self.fc = nn.Linear(in_channels, out_channels)

        # build normalization layer
        if self.with_norm:
            self.norm_name, norm = build_norm_layer(norm_cfg, out_channels)
            self.add_module(self.norm_name, norm)

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

    @property
    def norm(self):
        """The normalization layer built from ``norm_cfg``."""
        return getattr(self, self.norm_name)

    def forward(self, x, activate=True, norm=True):
        """Forward function; ``activate`` and ``norm`` toggle the optional
        activation and normalization layers for this call."""
        x = self.fc(x)
        if norm and self.with_norm:
            x = self.norm(x)
        if activate and self.with_activation:
            x = self.activate(x)
        return x
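
# A minimal usage sketch (an illustration, not part of the upstream file).
# It assumes mmcv/mmengine are installed so that build_norm_layer and
# build_activation_layer can resolve the 'BN1d' and 'ReLU' config types.
if __name__ == '__main__':
    import torch

    # Linear(16 -> 32) followed by BatchNorm1d and an in-place ReLU.
    fc = FcModule(
        in_channels=16,
        out_channels=32,
        norm_cfg=dict(type='BN1d'),
        act_cfg=dict(type='ReLU'))
    out = fc(torch.randn(4, 16))
    print(out.shape)  # torch.Size([4, 32])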