Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from typing import List, Optional | |
import torch | |
try: | |
import mmpretrain | |
from mmpretrain.models.classifiers import ImageClassifier | |
except ImportError: | |
mmpretrain = None | |
ImageClassifier = object | |
from mmdet.registry import MODELS | |
from mmdet.structures import ReIDDataSample | |
# NOTE(review): ``MODELS`` is imported in this file but this class is not
# registered with it here; confirm whether ``@MODELS.register_module()`` was
# intended on this class.
class BaseReID(ImageClassifier):
    """Base model for re-identification.

    A thin wrapper around mmpretrain's ``ImageClassifier`` whose
    :meth:`forward` additionally accepts a single 5-D sequence of frames.
    mmpretrain is an optional dependency: if it could not be imported,
    instantiating this class raises ``RuntimeError`` with install hints.
    """

    def __init__(self, *args, **kwargs):
        """Build the classifier, failing early if mmpretrain is missing.

        All positional and keyword arguments are forwarded unchanged to
        ``ImageClassifier.__init__``.

        Raises:
            RuntimeError: If mmpretrain could not be imported.
        """
        # When the mmpretrain import fails, ``ImageClassifier`` falls back to
        # ``object``; check here so the user gets an actionable message
        # instead of an obscure error from ``object.__init__``.
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and '
                               'run "mim install mmpretrain" to '
                               'install mmpretrain first.')
        super().__init__(*args, **kwargs)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[ReIDDataSample]] = None,
                mode: str = 'tensor'):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes: "tensor", "predict" and "loss":

        - "tensor": Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - "predict": Forward and return the predictions, which are fully
          processed to a list of :obj:`ReIDDataSample`.
        - "loss": Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating, which are done in the :meth:`train_step`.

        A 5-D input of shape (N, T, C, H, W) must have N == 1; its leading
        dimension is dropped so the T frames are processed as a batch of
        shape (T, C, H, W).

        Args:
            inputs (torch.Tensor): The input tensor with shape
                (N, C, H, W) or (N, T, C, H, W).
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. It's required if ``mode="loss"``.
                Defaults to None.
            mode (str): Return what kind of value. Defaults to 'tensor'.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of
              :obj:`ReIDDataSample`.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            ValueError: If ``inputs`` is 5-D and its first dimension is not 1.
        """
        if inputs.dim() == 5:
            # Only a single sequence is supported. Validate explicitly rather
            # than with ``assert`` (stripped under ``python -O``), which would
            # otherwise silently discard every sequence but the first.
            if inputs.size(0) != 1:
                raise ValueError(
                    'Expected a batch size of 1 for 5-D inputs of shape '
                    f'(N, T, C, H, W), but got N={inputs.size(0)}.')
            inputs = inputs[0]
        return super().forward(inputs, data_samples, mode)