import copy
import warnings

import torch
from mmcv import ConfigDict
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER
from mmcv.cnn.bricks.transformer import build_attention, build_feedforward_network
from mmcv.runner.base_module import BaseModule, ModuleList


@TRANSFORMER_LAYER.register_module()
class MyCustomBaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformers.

    It can be built from an `mmcv.ConfigDict` and supports flexible
    customization, for example, using any number of `FFN` or `LN` modules
    and different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It also supports `prenorm` when `norm` is specified
    as the first element of `operation_order`. For more details about
    `prenorm`, see `On Layer Normalization in the Transformer Architecture
    <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            the corresponding attentions in `operation_order`. If it is a
            dict, all of the attention modules in `operation_order` will be
            built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFNs. The order of the configs in the list should be
            consistent with the corresponding FFNs in `operation_order`.
            If it is a dict, all of the FFN modules in `operation_order`
            will be built with this config.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            `prenorm` is supported when the first element is 'norm'.
            Default: None.
        norm_cfg (dict): Config dict for the normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The config for initialization.
            Default: None.
        batch_first (bool): If True, key, query and value tensors have shape
            (batch, n, embed_dim); otherwise (n, batch, embed_dim).
            Default: True.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=True,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The argument `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated; please set `{new_name}` and '
                    f'other FFN related arguments in a dict named '
                    f'`ffn_cfgs` instead.')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(MyCustomBaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), \
            f'The operation_order of {self.__class__.__name__} should only ' \
            f"contain operations in {['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length of ' \
                f'attn_cfgs {len(attn_cfgs)} is not consistent with the ' \
                f'number of attentions in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Record which operation this attention module implements.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims
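        # Note: `embed_dims` is inferred from the first attention module, so
        # at least one attention config must define (or imply) `embed_dims`.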

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            # Keep every FFN width consistent with the attention embed_dims.
            ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index]))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            # build_norm_layer returns a (name, module) tuple; keep the module.
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MyCustomBaseTransformerLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if self.batch_first is False,
                else [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape
                [num_keys, bs, embed_dims] if self.batch_first is False,
                else [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensors used in the
                calculation of the corresponding attentions. The length of
                the list should equal the number of `attention` operations
                in `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layers.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with the same shape as `query`.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attentions in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query
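

# Usage sketch (illustrative, not part of the original file): given a layer
# built as in the config example above, a forward pass could look like
#
#     out = layer(query, key, value,
#                 query_pos=query_pos, key_pos=key_pos)
#
# where the tensors follow the shapes documented in `forward`.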


@TRANSFORMER_LAYER.register_module()
class BEVFormerLayer(MyCustomBaseTransformerLayer):
    """A transformer layer used by BEVFormer.

    It extends `MyCustomBaseTransformerLayer` by forwarding BEV-specific
    inputs (e.g. `prev_bev`, `ref_2d`, `ref_3d`) to its attention modules.

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict):
            Configs for self_attention or cross_attention; the order
            should be consistent with `operation_order`. If it is a dict,
            it will be expanded to the number of attentions in
            `operation_order`.
        feedforward_channels (int): The hidden dimension for FFNs.
        ffn_dropout (float): Probability of an element to be zeroed
            in FFN. Default: 0.0.
        operation_order (tuple[str]): The execution order of operations
            in the transformer, such as ('self_attn', 'norm', 'ffn', 'norm').
            Default: None.
        act_cfg (dict): The activation config for FFNs.
            Default: dict(type='ReLU', inplace=True).
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        ffn_num_fcs (int): The number of fully-connected layers in FFNs.
            Default: 2.
    """

    def __init__(self,
                 attn_cfgs,
                 feedforward_channels,
                 ffn_dropout=0.0,
                 operation_order=None,
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_cfg=dict(type='LN'),
                 ffn_num_fcs=2,
                 **kwargs):
        super(BEVFormerLayer, self).__init__(
            attn_cfgs=attn_cfgs,
            feedforward_channels=feedforward_channels,
            ffn_dropout=ffn_dropout,
            operation_order=operation_order,
            act_cfg=act_cfg,
            norm_cfg=norm_cfg,
            ffn_num_fcs=ffn_num_fcs,
            **kwargs)
        self.fp16_enabled = False
        # assert len(operation_order) == 6
        # assert set(operation_order) == set(
        #     ['self_attn', 'norm', 'cross_attn', 'ffn'])

    def forward(self,
                query,
                key=None,
                value=None,
                bev_pos=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                ref_2d=None,
                ref_3d=None,
                bev_h=None,
                bev_w=None,
                reference_points_cam=None,
                mask=None,
                spatial_shapes=None,
                level_start_index=None,
                prev_bev=None,
                **kwargs):
        """Forward function for `BEVFormerLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if self.batch_first is False,
                else [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape
                [num_keys, bs, embed_dims] if self.batch_first is False,
                else [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with the same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensors used in the
                calculation of the corresponding attentions. The length of
                the list should equal the number of `attention` operations
                in `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layers.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with the same shape as `query`.
        """
        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attentions in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                query = self.attentions[attn_index](
                    query,
                    prev_bev,
                    prev_bev,
                    identity if self.pre_norm else None,
                    query_pos=bev_pos,
                    key_pos=bev_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    reference_points=ref_2d,
                    spatial_shapes=torch.tensor(
                        [[bev_h, bev_w]], device=query.device),
                    level_start_index=torch.tensor([0], device=query.device),
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    reference_points=ref_3d,
                    reference_points_cam=reference_points_cam,
                    mask=mask,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    spatial_shapes=spatial_shapes,
                    level_start_index=level_start_index,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query