Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from typing import List, Optional, Tuple, Union | |
import torch | |
import torch.nn as nn | |
try: | |
import mmpretrain | |
from mmpretrain.evaluation.metrics import Accuracy | |
except ImportError: | |
mmpretrain = None | |
from mmengine.model import BaseModule | |
from mmdet.registry import MODELS | |
from mmdet.structures import ReIDDataSample | |
from .fc_module import FcModule | |
class LinearReIDHead(BaseModule):
    """Linear head for re-identification.

    The head stacks ``num_fcs`` :class:`FcModule` layers followed by a final
    linear projection to ``out_channels``. When a classification loss is
    configured, a ``BatchNorm1d`` + linear classifier branch is added on top
    of the projected feature and used only while computing the loss.

    Args:
        num_fcs (int): Number of fcs.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after fc.
            Defaults to None.
        num_classes (int, optional): Number of the identities. Required when
            ``loss_cls`` is set. Defaults to None.
        loss_cls (dict, optional): Cross entropy loss to train the ReID module.
            Defaults to None.
        loss_triplet (dict, optional): Triplet loss to train the ReID module.
            Defaults to None.
        topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal',layer='Linear', mean=0, std=0.01,
            bias=0).

    Raises:
        RuntimeError: If ``mmpretrain`` could not be imported.
        ValueError: If neither ``loss_cls`` nor ``loss_triplet`` is given.
        TypeError: If ``loss_cls`` is given but ``num_classes`` is not an int.
    """

    def __init__(self,
                 num_fcs: int,
                 in_channels: int,
                 fc_channels: int,
                 out_channels: int,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = None,
                 num_classes: Optional[int] = None,
                 loss_cls: Optional[dict] = None,
                 loss_triplet: Optional[dict] = None,
                 topk: Union[int, Tuple[int]] = (1, ),
                 init_cfg: Union[dict, List[dict]] = dict(
                     type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
        # ``Accuracy`` comes from mmpretrain, so fail early with install
        # instructions when the optional import at module level failed.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(init_cfg=init_cfg)

        # Normalize ``topk`` to a tuple of positive ints.
        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

        # At least one loss is required; ``num_classes`` only matters (and is
        # mandatory) when the classification loss is used.
        if loss_cls is None:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy is not set, '
                              'the num_classes will be ignored.')
            if loss_triplet is None:
                raise ValueError('Please choose at least one loss in '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            raise TypeError('The num_classes must be an integer, '
                            'if there is cross entropy loss.')

        self.loss_cls = MODELS.build(loss_cls) if loss_cls else None
        self.loss_triplet = MODELS.build(loss_triplet) \
            if loss_triplet else None

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes

        self._init_layers()

    def _init_layers(self):
        """Initialize fc layers.

        Builds ``self.fcs`` (the stacked :class:`FcModule` layers) and
        ``self.fc_out`` (the final projection). When a classification loss
        was built, also creates ``self.bn`` and ``self.classifier``.
        """
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            # Only the first fc consumes the raw input width.
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))

        # Decide the projection's input width from the layers that were
        # actually built, so that ``num_fcs <= 0`` (no fc layers) always
        # wires ``fc_out`` directly to the input features.
        in_channels = self.fc_channels if len(self.fcs) > 0 else \
            self.in_channels
        self.fc_out = nn.Linear(in_channels, self.out_channels)

        if self.loss_cls is not None:
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Args:
            feats (tuple[Tensor]): Multi-stage features; only the last
                element is used.

        Returns:
            Tensor: Output of ``fc_out`` applied after all fc layers.
        """
        # Multiple stage inputs are acceptable
        # but only the last stage will be used.
        feats = feats[-1]

        for m in self.fcs:
            feats = m(feats)

        feats = self.fc_out(feats)
        return feats

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ReIDDataSample]) -> dict:
        """Calculate losses.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[ReIDDataSample]): The annotation data of
                every samples.

        Returns:
            dict: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        losses = self.loss_by_feat(feats, data_samples)
        return losses

    def loss_by_feat(self, feats: torch.Tensor,
                     data_samples: List[ReIDDataSample]) -> dict:
        """Unpack data samples and compute loss.

        Args:
            feats (Tensor): Features returned by :meth:`forward`.
            data_samples (List[ReIDDataSample]): Samples whose
                ``gt_label.label`` tensors are concatenated into the target.

        Returns:
            dict: May contain ``triplet_loss``, ``ce_loss`` and one
            ``accuracy_top-k`` entry per value in ``self.topk``, depending on
            which losses are configured.
        """
        losses = dict()
        gt_label = torch.cat([i.gt_label.label for i in data_samples])
        gt_label = gt_label.to(feats.device)

        if self.loss_triplet is not None:
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

        if self.loss_cls is not None:
            # The classifier branch operates on the batch-normalized feature.
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Inference without augmentation.

        Args:
            feats (Tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every samples. If not None, set ``pred_feature`` of
                the input data samples. Defaults to None.

        Returns:
            List[ReIDDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        data_samples = self.predict_by_feat(feats, data_samples)
        return data_samples

    def predict_by_feat(
        self,
        feats: torch.Tensor,
        data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Add prediction features to data samples.

        Args:
            feats (Tensor): Features returned by :meth:`forward`; iterated
                along the first dimension.
            data_samples (List[ReIDDataSample], optional): Existing samples
                to fill in place. When None, new samples are created, one per
                feature. Defaults to None.

        Returns:
            List[ReIDDataSample]: Samples with ``pred_feature`` set.
        """
        if data_samples is not None:
            for data_sample, feat in zip(data_samples, feats):
                data_sample.pred_feature = feat
        else:
            data_samples = []
            for feat in feats:
                data_sample = ReIDDataSample()
                data_sample.pred_feature = feat
                data_samples.append(data_sample)

        return data_samples