Spaces:
Runtime error
Runtime error
File size: 7,665 Bytes
3b96cb1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
try:
import mmpretrain
from mmpretrain.evaluation.metrics import Accuracy
except ImportError:
mmpretrain = None
from mmengine.model import BaseModule
from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample
from .fc_module import FcModule
@MODELS.register_module()
class LinearReIDHead(BaseModule):
    """Linear head for re-identification.

    The head applies ``num_fcs`` :class:`FcModule` layers followed by a final
    ``nn.Linear`` projection to ``out_channels``. When a classification loss
    is configured, an extra ``BatchNorm1d`` + ``nn.Linear`` classifier is
    built on top of the projected features and used only in
    :meth:`loss_by_feat`.

    Args:
        num_fcs (int): Number of fcs.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after fc.
            Defaults to None.
        num_classes (int, optional): Number of the identities. Required when
            ``loss_cls`` is set and ignored (with a warning) otherwise.
            Defaults to None.
        loss_cls (dict, optional): Cross entropy loss to train the ReID module.
            Defaults to None.
        loss_triplet (dict, optional): Triplet loss to train the ReID module.
            Defaults to None.
        topk (int | Tuple[int]): Top-k accuracy. A list of ints is also
            accepted and converted to a tuple. Defaults to ``(1, )``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal', layer='Linear', mean=0, std=0.01,
            bias=0).

    Raises:
        RuntimeError: If mmpretrain is not installed.
        TypeError: If ``topk`` is not an int, tuple or list, or if
            ``loss_cls`` is set but ``num_classes`` is not an int.
        ValueError: If any top-k value is not larger than 0, or if neither
            ``loss_cls`` nor ``loss_triplet`` is set.
    """

    def __init__(self,
                 num_fcs: int,
                 in_channels: int,
                 fc_channels: int,
                 out_channels: int,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = None,
                 num_classes: Optional[int] = None,
                 loss_cls: Optional[dict] = None,
                 loss_triplet: Optional[dict] = None,
                 topk: Union[int, Tuple[int]] = (1, ),
                 init_cfg: Union[dict, List[dict]] = dict(
                     type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
        # ``Accuracy`` (used in ``loss_by_feat``) comes from mmpretrain, so
        # fail early with an actionable message when it is missing.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        # NOTE(review): ``init_cfg`` has a shared dict default; this is only
        # safe as long as BaseModule does not mutate it in place — confirm.
        super().__init__(init_cfg=init_cfg)

        # Normalize ``topk`` to a tuple of positive values. Raise explicitly
        # instead of using ``assert``, which is stripped under ``python -O``.
        if isinstance(topk, int):
            topk = (topk, )
        elif isinstance(topk, (tuple, list)):
            topk = tuple(topk)
        else:
            raise TypeError('topk must be an int or a tuple of ints, '
                            f'but got {type(topk)}')
        for k in topk:
            if k <= 0:
                raise ValueError('Top-k should be larger than 0')
        self.topk = topk

        # At least one loss is required; the classification loss
        # additionally needs an integer number of identities.
        if loss_cls is None:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy is not set, '
                              'the num_classes will be ignored.')
            if loss_triplet is None:
                raise ValueError('Please choose at least one loss in '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            raise TypeError('The num_classes must be an integer, '
                            'if there is cross entropy loss.')
        # Empty / None configs mean "no such loss".
        self.loss_cls = MODELS.build(loss_cls) if loss_cls else None
        self.loss_triplet = (
            MODELS.build(loss_triplet) if loss_triplet else None)

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes

        self._init_layers()

    def _init_layers(self):
        """Initialize the fc layers, the output projection and, when a
        classification loss is set, the BN + classifier branch."""
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))
        # With no fcs the projection consumes the raw input channels.
        in_channels = (
            self.in_channels if self.num_fcs == 0 else self.fc_channels)
        self.fc_out = nn.Linear(in_channels, self.out_channels)
        if self.loss_cls is not None:
            # Only needed to produce logits for the cross-entropy loss.
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """Project the last-stage features to ``out_channels``.

        Args:
            feats (Tuple[Tensor]): Multi-stage features. Multiple stage inputs
                are acceptable but only the last stage (``feats[-1]``) is
                used.

        Returns:
            Tensor: Output of ``fc_out`` applied after all fc layers.
        """
        feats = feats[-1]
        for m in self.fcs:
            feats = m(feats)
        feats = self.fc_out(feats)
        return feats

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ReIDDataSample]) -> dict:
        """Calculate losses.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[ReIDDataSample]): The annotation data of
                every sample.

        Returns:
            dict: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        losses = self.loss_by_feat(feats, data_samples)
        return losses

    def loss_by_feat(self, feats: torch.Tensor,
                     data_samples: List[ReIDDataSample]) -> dict:
        """Unpack data samples and compute losses.

        Args:
            feats (Tensor): Output of :meth:`forward`.
            data_samples (List[ReIDDataSample]): Samples whose
                ``gt_label.label`` tensors are concatenated as targets.

        Returns:
            dict: ``'triplet_loss'`` if the triplet loss is set; ``'ce_loss'``
            and one ``'accuracy_top-{k}'`` entry per ``k`` in ``self.topk``
            if the classification loss is set.
        """
        losses = dict()
        gt_label = torch.cat([i.gt_label.label for i in data_samples])
        gt_label = gt_label.to(feats.device)

        if self.loss_triplet is not None:
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

        if self.loss_cls is not None:
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            # One result per requested k, in the order of ``self.topk``;
            # each value is stored exactly as ``Accuracy.calculate`` returns
            # it for that k.
            acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
            self,
            feats: Tuple[torch.Tensor],
            data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Inference without augmentation.

        Args:
            feats (Tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. If not None, ``pred_feature`` of the
                input data samples is set. Defaults to None.

        Returns:
            List[ReIDDataSample]: A list of data samples which contains the
            predicted features.
        """
        # The part can be traced by torch.fx
        feats = self(feats)

        # The part can not be traced by torch.fx
        data_samples = self.predict_by_feat(feats, data_samples)

        return data_samples

    def predict_by_feat(
            self,
            feats: torch.Tensor,
            data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Attach per-sample prediction features to data samples.

        Args:
            feats (Tensor): Output of :meth:`forward`; iterated along its
                first dimension.
            data_samples (List[ReIDDataSample], optional): If given, each
                sample's ``pred_feature`` is set in place, pairing samples and
                features with ``zip`` (surplus items on either side are
                ignored). If None, new ``ReIDDataSample`` objects are created.

        Returns:
            List[ReIDDataSample]: Data samples carrying ``pred_feature``.
        """
        if data_samples is not None:
            for data_sample, feat in zip(data_samples, feats):
                data_sample.pred_feature = feat
        else:
            data_samples = []
            for feat in feats:
                data_sample = ReIDDataSample()
                data_sample.pred_feature = feat
                data_samples.append(data_sample)

        return data_samples
|