TTP / mmdet /models /reid /linear_reid_head.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
try:
import mmpretrain
from mmpretrain.evaluation.metrics import Accuracy
except ImportError:
mmpretrain = None
from mmengine.model import BaseModule
from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample
from .fc_module import FcModule
@MODELS.register_module()
class LinearReIDHead(BaseModule):
    """Linear head for re-identification.

    The head is a stack of ``num_fcs`` :class:`FcModule` layers followed by
    a linear projection to ``out_channels``. When a classification loss is
    configured, an additional ``BatchNorm1d`` + ``Linear`` classifier is
    built on top of the projected features and used only for the loss.

    Args:
        num_fcs (int): Number of fcs.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after fc.
            Defaults to None.
        num_classes (int, optional): Number of the identities. Required when
            ``loss_cls`` is set, ignored (with a warning) otherwise.
            Defaults to None.
        loss_cls (dict, optional): Cross entropy loss to train the ReID module.
            Defaults to None.
        loss_triplet (dict, optional): Triplet loss to train the ReID module.
            Defaults to None.
        topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal',layer='Linear', mean=0, std=0.01,
            bias=0).

    Raises:
        RuntimeError: If ``mmpretrain`` is not installed.
        AssertionError: If ``topk`` is neither an int nor a tuple, or if
            any of its entries is not larger than 0.
        ValueError: If both ``loss_cls`` and ``loss_triplet`` are None.
        TypeError: If ``loss_cls`` is given but ``num_classes`` is not an
            int.
    """

    def __init__(self,
                 num_fcs: int,
                 in_channels: int,
                 fc_channels: int,
                 out_channels: int,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = None,
                 num_classes: Optional[int] = None,
                 loss_cls: Optional[dict] = None,
                 loss_triplet: Optional[dict] = None,
                 topk: Union[int, Tuple[int]] = (1, ),
                 init_cfg: Union[dict, List[dict]] = dict(
                     type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
        # ``Accuracy`` comes from mmpretrain, which is an optional
        # dependency of the package; fail early with an actionable message.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(init_cfg=init_cfg)

        # Normalize ``topk`` to a tuple so the accuracy code can iterate it.
        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

        if loss_cls is None:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy is not set, '
                              'the num_classes will be ignored.')
            if loss_triplet is None:
                raise ValueError('Please choose at least one loss in '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            # The classifier layer needs a concrete output size.
            raise TypeError('The num_classes must be an integer, '
                            'if there is cross entropy loss.')

        self.loss_cls = MODELS.build(loss_cls) if loss_cls else None
        self.loss_triplet = MODELS.build(loss_triplet) \
            if loss_triplet else None

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes

        self._init_layers()

    def _init_layers(self):
        """Initialize fc layers."""
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            # Only the first fc consumes the raw input width.
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))
        # With no intermediate fcs the output layer reads the input directly.
        in_channels = self.in_channels if self.num_fcs == 0 else \
            self.fc_channels
        self.fc_out = nn.Linear(in_channels, self.out_channels)
        # Compare against None explicitly instead of relying on the
        # truthiness of the built loss module.
        if self.loss_cls is not None:
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last
                stage will be used.

        Returns:
            Tensor: The output of the final linear layer ``fc_out``.
        """
        # Multiple stage inputs are acceptable
        # but only the last stage will be used.
        feats = feats[-1]

        for m in self.fcs:
            feats = m(feats)

        feats = self.fc_out(feats)

        return feats

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ReIDDataSample]) -> dict:
        """Calculate losses.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[ReIDDataSample]): The annotation data of
                every samples.

        Returns:
            dict: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        losses = self.loss_by_feat(feats, data_samples)
        return losses

    def loss_by_feat(self, feats: torch.Tensor,
                     data_samples: List[ReIDDataSample]) -> dict:
        """Unpack data samples and compute loss.

        Args:
            feats (Tensor): The features produced by :meth:`forward`.
            data_samples (List[ReIDDataSample]): The annotation data of
                every samples. ``gt_label.label`` of each sample is read.

        Returns:
            dict: May contain ``triplet_loss`` (when the triplet loss is
            configured), and ``ce_loss`` together with one
            ``accuracy_top-{k}`` entry per value in ``self.topk`` (when the
            classification loss is configured).
        """
        losses = dict()
        # Gather the per-sample labels into one tensor on the feature device.
        gt_label = torch.cat([i.gt_label.label for i in data_samples])
        gt_label = gt_label.to(feats.device)

        if self.loss_triplet is not None:
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

        if self.loss_cls is not None:
            # The classifier operates on batch-normalized features, while
            # the triplet loss above uses the raw projected features.
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Inference without augmentation.

        Args:
            feats (Tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every samples. If not None, set ``pred_feature`` of
                the input data samples. Defaults to None.

        Returns:
            List[ReIDDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        data_samples = self.predict_by_feat(feats, data_samples)

        return data_samples

    def predict_by_feat(
        self,
        feats: torch.Tensor,
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Add prediction features to data samples.

        Args:
            feats (Tensor): The features produced by :meth:`forward`;
                iterated along its first dimension.
            data_samples (List[ReIDDataSample], optional): Samples to attach
                the features to. If None, a fresh ``ReIDDataSample`` is
                created for every feature. Defaults to None.

        Returns:
            List[ReIDDataSample]: The data samples with ``pred_feature``
            set. When ``data_samples`` is given, the same list is returned
            and its items are modified in place.
        """
        if data_samples is None:
            data_samples = [ReIDDataSample() for _ in feats]

        # NOTE: ``zip`` stops at the shorter input, so the two are assumed
        # to have the same length.
        for data_sample, feat in zip(data_samples, feats):
            data_sample.pred_feature = feat

        return data_samples