# Copyright (c) OpenMMLab. All rights reserved.
import re
from typing import List, Optional, Tuple, Union

import torch.nn as nn

from mmpretrain.models.utils.sparse_modules import (SparseAvgPooling,
                                                    SparseBatchNorm2d,
                                                    SparseConv2d,
                                                    SparseMaxPooling,
                                                    SparseSyncBatchNorm2d)
from mmpretrain.registry import MODELS

from .resnet import ResNet
@MODELS.register_module()
class SparseResNet(ResNet):
    """ResNet with sparse module conversion function.

    Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Defaults to 3.
        stem_channels (int): Output channels of the stem layer. Defaults to 64.
        base_channels (int): Middle channels of the first stage.
            Defaults to 64.
        num_stages (int): Stages of the network. Defaults to 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Defaults to ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Defaults to ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages.
            Defaults to ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv.
            Defaults to False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Defaults to None.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='SparseSyncBatchNorm2d')``.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Defaults to False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Defaults to False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Defaults to False.
        init_cfg (dict | List[dict] | None): Initialization config dict.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
    """

    def __init__(self,
                 depth: int,
                 in_channels: int = 3,
                 stem_channels: int = 64,
                 base_channels: int = 64,
                 expansion: Optional[int] = None,
                 num_stages: int = 4,
                 strides: Tuple[int] = (1, 2, 2, 2),
                 dilations: Tuple[int] = (1, 1, 1, 1),
                 out_indices: Tuple[int] = (3, ),
                 style: str = 'pytorch',
                 deep_stem: bool = False,
                 avg_down: bool = False,
                 frozen_stages: int = -1,
                 conv_cfg: Optional[dict] = None,
                 norm_cfg: dict = dict(type='SparseSyncBatchNorm2d'),
                 norm_eval: bool = False,
                 with_cp: bool = False,
                 zero_init_residual: bool = False,
                 init_cfg: Optional[Union[dict, List[dict]]] = [
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ],
                 drop_path_rate: float = 0,
                 **kwargs):
        super().__init__(
            depth=depth,
            in_channels=in_channels,
            stem_channels=stem_channels,
            base_channels=base_channels,
            expansion=expansion,
            num_stages=num_stages,
            strides=strides,
            dilations=dilations,
            out_indices=out_indices,
            style=style,
            deep_stem=deep_stem,
            avg_down=avg_down,
            frozen_stages=frozen_stages,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            norm_eval=norm_eval,
            with_cp=with_cp,
            zero_init_residual=zero_init_residual,
            init_cfg=init_cfg,
            drop_path_rate=drop_path_rate,
            **kwargs)
        # Use the SyncBN-flavored sparse norm whenever the configured norm
        # type name mentions ``Sync`` (e.g. ``SparseSyncBatchNorm2d``).
        enable_sync_bn = 'Sync' in norm_cfg['type']
        self.dense_model_to_sparse(m=self, enable_sync_bn=enable_sync_bn)

    def dense_model_to_sparse(self, m: nn.Module,
                              enable_sync_bn: bool) -> nn.Module:
        """Recursively convert regular dense modules to sparse modules.

        Conv2d, MaxPool2d, AvgPool2d and (Sync)BatchNorm2d layers are
        replaced with their sparse counterparts; weights and running
        statistics are copied over. All other modules are kept as-is, with
        their children converted in place.

        Args:
            m (nn.Module): The (sub)module to convert.
            enable_sync_bn (bool): If True, batch norm layers are converted
                to ``SparseSyncBatchNorm2d``, otherwise to
                ``SparseBatchNorm2d``.

        Returns:
            nn.Module: The converted module.

        Raises:
            NotImplementedError: If a ``Conv1d`` layer is encountered.
        """
        output = m
        if isinstance(m, nn.Conv2d):
            has_bias = m.bias is not None
            output = SparseConv2d(
                m.in_channels,
                m.out_channels,
                kernel_size=m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                groups=m.groups,
                bias=has_bias,
                padding_mode=m.padding_mode,
            )
            output.weight.data.copy_(m.weight.data)
            if has_bias:
                output.bias.data.copy_(m.bias.data)
        elif isinstance(m, nn.MaxPool2d):
            output = SparseMaxPooling(
                m.kernel_size,
                stride=m.stride,
                padding=m.padding,
                dilation=m.dilation,
                return_indices=m.return_indices,
                ceil_mode=m.ceil_mode)
        elif isinstance(m, nn.AvgPool2d):
            output = SparseAvgPooling(
                m.kernel_size,
                m.stride,
                m.padding,
                ceil_mode=m.ceil_mode,
                count_include_pad=m.count_include_pad,
                divisor_override=m.divisor_override)
        elif isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
            bn_cls = (SparseSyncBatchNorm2d
                      if enable_sync_bn else SparseBatchNorm2d)
            # ``num_features`` is read directly instead of ``weight.shape[0]``
            # so non-affine norms (weight is None) can also be converted.
            output = bn_cls(
                m.num_features,
                eps=m.eps,
                momentum=m.momentum,
                affine=m.affine,
                track_running_stats=m.track_running_stats)
            # Affine params / running stats only exist when the corresponding
            # flags are set; guard the copies to avoid ``None`` dereference.
            if m.affine:
                output.weight.data.copy_(m.weight.data)
                output.bias.data.copy_(m.bias.data)
            if m.track_running_stats:
                output.running_mean.data.copy_(m.running_mean.data)
                output.running_var.data.copy_(m.running_var.data)
                output.num_batches_tracked.data.copy_(
                    m.num_batches_tracked.data)
        elif isinstance(m, (nn.Conv1d, )):
            raise NotImplementedError
        # Recurse into children of the ORIGINAL module and attach the
        # converted versions onto the (possibly replaced) output module.
        for name, child in m.named_children():
            output.add_module(
                name,
                self.dense_model_to_sparse(
                    child, enable_sync_bn=enable_sync_bn))
        del m
        return output