TTP / mmdet /models /reid /base_reid.py
KyanChen's picture
Upload 1861 files
3b96cb1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
try:
import mmpretrain
from mmpretrain.models.classifiers import ImageClassifier
except ImportError:
mmpretrain = None
ImageClassifier = object
from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample
@MODELS.register_module()
class BaseReID(ImageClassifier):
"""Base model for re-identification."""
def __init__(self, *args, **kwargs):
if mmpretrain is None:
raise RuntimeError('Please run "pip install openmim" and '
'run "mim install mmpretrain" to '
'install mmpretrain first.')
super().__init__(*args, **kwargs)
def forward(self,
inputs: torch.Tensor,
data_samples: Optional[List[ReIDDataSample]] = None,
mode: str = 'tensor'):
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
- "tensor": Forward the whole network and return tensor or tuple of
tensor without any post-processing, same as a common nn.Module.
- "predict": Forward and return the predictions, which are fully
processed to a list of :obj:`ReIDDataSample`.
- "loss": Forward and return a dict of losses according to the given
inputs and data samples.
Note that this method doesn't handle neither back propagation nor
optimizer updating, which are done in the :meth:`train_step`.
Args:
inputs (torch.Tensor): The input tensor with shape
(N, C, H, W) or (N, T, C, H, W).
data_samples (List[ReIDDataSample], optional): The annotation
data of every sample. It's required if ``mode="loss"``.
Defaults to None.
mode (str): Return what kind of value. Defaults to 'tensor'.
Returns:
The return type depends on ``mode``.
- If ``mode="tensor"``, return a tensor or a tuple of tensor.
- If ``mode="predict"``, return a list of
:obj:`ReIDDataSample`.
- If ``mode="loss"``, return a dict of tensor.
"""
if len(inputs.size()) == 5:
assert inputs.size(0) == 1
inputs = inputs[0]
return super().forward(inputs, data_samples, mode)