Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from mmcv.cnn import ConvModule | |
| from mmengine.model import BaseModule | |
| from torch.utils.checkpoint import checkpoint | |
| from mmdet.registry import MODELS | |
class HRFPN(BaseModule):
    """High-Resolution Feature Pyramid (HRFPN) neck.

    Described in `High-Resolution Representations for Labeling Pixels and
    Regions <https://arxiv.org/abs/1904.04514>`_.

    The forward pass bilinearly upsamples input branch ``i`` by a factor of
    ``2**i``, concatenates all branches along dim 1, fuses them with a 1x1
    convolution, derives ``num_outs`` levels by pooling the fused map with
    kernel size and stride ``2**i``, and applies one 3x3 convolution to each
    level.

    Args:
        in_channels (list): Number of channels of each input branch.
        out_channels (int): Number of channels of every output level.
        num_outs (int): Number of output levels. Defaults to 5.
        pooling_type (str): ``'MAX'`` selects max pooling; any other value
            selects average pooling. Defaults to ``'AVG'``.
        conv_cfg (dict, optional): Passed as ``conv_cfg`` to every
            ``ConvModule`` built here. Defaults to None.
        norm_cfg (dict, optional): Stored on the instance. Note that the
            conv layers built in this class are not given this config.
            Defaults to None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save
            some memory while slowing down the training speed. It is only
            applied to tensors that require grad. Defaults to False.
        stride (int): Stride of the 3x3 convolutional layers. Defaults to 1.
        init_cfg (dict or list[dict], optional): Initialization config
            handed to the base class constructor.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs=5,
                 pooling_type='AVG',
                 conv_cfg=None,
                 norm_cfg=None,
                 with_cp=False,
                 stride=1,
                 init_cfg=dict(type='Caffe2Xavier', layer='Conv2d')):
        super().__init__(init_cfg)
        assert isinstance(in_channels, list)

        self.in_channels = in_channels
        self.num_ins = len(in_channels)
        self.out_channels = out_channels
        self.num_outs = num_outs
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # 1x1 conv that fuses the channel-wise concatenation of all branches.
        self.reduction_conv = ConvModule(
            sum(in_channels),
            out_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            act_cfg=None)

        # One 3x3 conv per output level, built after the reduction conv.
        self.fpn_convs = nn.ModuleList([
            ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                stride=stride,
                conv_cfg=conv_cfg,
                act_cfg=None) for _ in range(num_outs)
        ])

        # Anything other than 'MAX' falls back to average pooling.
        self.pooling = (
            F.max_pool2d if pooling_type == 'MAX' else F.avg_pool2d)

    def _maybe_checkpoint(self, layer, feat):
        """Run ``layer`` on ``feat``, through ``checkpoint`` when enabled.

        Checkpointing is used only if ``feat`` requires grad and
        ``self.with_cp`` is set; otherwise the layer is called directly.
        """
        if feat.requires_grad and self.with_cp:
            return checkpoint(layer, feat)
        return layer(feat)

    def forward(self, inputs):
        """Forward function.

        Args:
            inputs (Sequence[Tensor]): One feature map per branch; its
                length must equal ``len(in_channels)``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps, one per pyramid level.
        """
        assert len(inputs) == self.num_ins

        # Branch 0 is used as-is; branch i is upsampled by a factor of 2**i.
        branches = [inputs[0]] + [
            F.interpolate(inputs[lvl], scale_factor=2**lvl, mode='bilinear')
            for lvl in range(1, self.num_ins)
        ]
        fused = self._maybe_checkpoint(
            self.reduction_conv, torch.cat(branches, dim=1))

        # Level 0 is the fused map; level i is pooled with kernel/stride 2**i.
        pyramid = [fused] + [
            self.pooling(fused, kernel_size=2**lvl, stride=2**lvl)
            for lvl in range(1, self.num_outs)
        ]

        return tuple(
            self._maybe_checkpoint(conv, level)
            for conv, level in zip(self.fpn_convs, pyramid))