# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

try:
    import mmpretrain
    from mmpretrain.evaluation.metrics import Accuracy
except ImportError:
    mmpretrain = None

from mmengine.model import BaseModule

from mmdet.registry import MODELS
from mmdet.structures import ReIDDataSample
from .fc_module import FcModule


@MODELS.register_module()
class LinearReIDHead(BaseModule):
    """Linear head for re-identification.

    Args:
        num_fcs (int): Number of fcs.
        in_channels (int): Number of channels in the input.
        fc_channels (int): Number of channels in the fcs.
        out_channels (int): Number of channels in the output.
        norm_cfg (dict, optional): Configuration of normalization method
            after fc. Defaults to None.
        act_cfg (dict, optional): Configuration of activation method after fc.
            Defaults to None.
        num_classes (int, optional): Number of the identities. Defaults to
            None.
        loss_cls (dict, optional): Cross entropy loss to train the ReID module.
            Defaults to None.
        loss_triplet (dict, optional): Triplet loss to train the ReID module.
            Defaults to None.
        topk (int | Tuple[int, ...]): Top-k accuracy. Defaults to ``(1, )``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to dict(type='Normal', layer='Linear', mean=0, std=0.01,
            bias=0).
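
    Example:
        A minimal construction sketch; the channel sizes, class count and
        loss configs below are illustrative assumptions, not canonical
        defaults.

        >>> head = LinearReIDHead(
        ...     num_fcs=1,
        ...     in_channels=2048,
        ...     fc_channels=1024,
        ...     out_channels=128,
        ...     num_classes=380,
        ...     loss_cls=dict(type='mmpretrain.CrossEntropyLoss',
        ...                   loss_weight=1.0),
        ...     loss_triplet=dict(type='TripletLoss', margin=0.3,
        ...                       loss_weight=1.0))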
    """

    def __init__(self,
                 num_fcs: int,
                 in_channels: int,
                 fc_channels: int,
                 out_channels: int,
                 norm_cfg: Optional[dict] = None,
                 act_cfg: Optional[dict] = None,
                 num_classes: Optional[int] = None,
                 loss_cls: Optional[dict] = None,
                 loss_triplet: Optional[dict] = None,
                 topk: Union[int, Tuple[int, ...]] = (1, ),
                 init_cfg: Union[dict, List[dict]] = dict(
                     type='Normal', layer='Linear', mean=0, std=0.01, bias=0)):
        if mmpretrain is None:
            raise RuntimeError('Please run "pip install openmim" and then '
                               '"mim install mmpretrain" to install '
                               'mmpretrain first.')
        super().__init__(init_cfg=init_cfg)

        assert isinstance(topk, (int, tuple))
        if isinstance(topk, int):
            topk = (topk, )
        for _topk in topk:
            assert _topk > 0, 'Top-k should be larger than 0'
        self.topk = topk

        if loss_cls is None:
            if isinstance(num_classes, int):
                warnings.warn('Since cross entropy loss is not set, '
                              'num_classes will be ignored.')
            if loss_triplet is None:
                raise ValueError('Please choose at least one loss between '
                                 'triplet loss and cross entropy loss.')
        elif not isinstance(num_classes, int):
            raise TypeError('num_classes must be an integer when '
                            'cross entropy loss is used.')
        self.loss_cls = MODELS.build(loss_cls) if loss_cls else None
        self.loss_triplet = MODELS.build(loss_triplet) \
            if loss_triplet else None

        self.num_fcs = num_fcs
        self.in_channels = in_channels
        self.fc_channels = fc_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.num_classes = num_classes

        self._init_layers()

    def _init_layers(self):
        """Initialize fc layers."""
        self.fcs = nn.ModuleList()
        for i in range(self.num_fcs):
            in_channels = self.in_channels if i == 0 else self.fc_channels
            self.fcs.append(
                FcModule(in_channels, self.fc_channels, self.norm_cfg,
                         self.act_cfg))
        in_channels = self.in_channels if self.num_fcs == 0 else \
            self.fc_channels
        self.fc_out = nn.Linear(in_channels, self.out_channels)
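        # The classification branch (BN + identity classifier) is only
        # built when a classification loss is configured.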
        if self.loss_cls:
            self.bn = nn.BatchNorm1d(self.out_channels)
            self.classifier = nn.Linear(self.out_channels, self.num_classes)

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process."""
        # Multiple stage inputs are acceptable
        # but only the last stage will be used.
        feats = feats[-1]

        for m in self.fcs:
            feats = m(feats)
        feats = self.fc_out(feats)
        return feats

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ReIDDataSample]) -> dict:
        """Calculate losses.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
            data_samples (List[ReIDDataSample]): The annotation data of
                every sample.

        Returns:
            dict: A dictionary of loss components.
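
        Example:
            A minimal sketch with random features and dummy labels; the
            triplet-only head below is an illustrative assumption.

            >>> head = LinearReIDHead(num_fcs=0, in_channels=16,
            ...                       fc_channels=16, out_channels=8,
            ...                       loss_triplet=dict(type='TripletLoss'))
            >>> feats = (torch.rand(4, 16), )
            >>> samples = [
            ...     ReIDDataSample().set_gt_label(i) for i in (0, 0, 1, 1)]
            >>> sorted(head.loss(feats, samples).keys())
            ['triplet_loss']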
        """
        # This part can be traced by torch.fx
        feats = self(feats)

        # This part cannot be traced by torch.fx
        losses = self.loss_by_feat(feats, data_samples)
        return losses

    def loss_by_feat(self, feats: torch.Tensor,
                     data_samples: List[ReIDDataSample]) -> dict:
        """Unpack data samples and compute loss."""
        losses = dict()
        gt_label = torch.cat([i.gt_label.label for i in data_samples])
        gt_label = gt_label.to(feats.device)

        if self.loss_triplet:
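            # The triplet loss is computed directly on the output embeddings.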
            losses['triplet_loss'] = self.loss_triplet(feats, gt_label)

        if self.loss_cls:
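            # The cross entropy loss acts on batch-normalized embeddings
            # passed through the identity classifier.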
            feats_bn = self.bn(feats)
            cls_score = self.classifier(feats_bn)
            losses['ce_loss'] = self.loss_cls(cls_score, gt_label)
            acc = Accuracy.calculate(cls_score, gt_label, topk=self.topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(self.topk, acc)})

        return losses

    def predict(
            self,
            feats: Tuple[torch.Tensor],
            data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Inference without augmentation.

        Args:
            feats (Tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used.
            data_samples (List[ReIDDataSample], optional): The annotation
                data of every sample. If not None, set ``pred_feature`` of
                the input data samples. Defaults to None.

        Returns:
            List[ReIDDataSample]: A list of data samples which contains the
            predicted results.
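
        Example:
            A minimal inference sketch with random features; the head
            configuration and shapes are illustrative assumptions.

            >>> head = LinearReIDHead(num_fcs=0, in_channels=16,
            ...                       fc_channels=16, out_channels=8,
            ...                       loss_triplet=dict(type='TripletLoss'))
            >>> feats = (torch.rand(4, 16), )
            >>> samples = head.predict(feats)
            >>> samples[0].pred_feature.shape
            torch.Size([8])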
        """
        # This part can be traced by torch.fx
        feats = self(feats)

        # This part cannot be traced by torch.fx
        data_samples = self.predict_by_feat(feats, data_samples)

        return data_samples

    def predict_by_feat(
            self,
            feats: torch.Tensor,
            data_samples: Optional[List[ReIDDataSample]] = None
    ) -> List[ReIDDataSample]:
        """Add prediction features to data samples."""
        if data_samples is not None:
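            # Attach each predicted feature to its given data sample.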
            for data_sample, feat in zip(data_samples, feats):
                data_sample.pred_feature = feat
        else:
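            # No data samples were given; create one per predicted feature.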
            data_samples = []
            for feat in feats:
                data_sample = ReIDDataSample()
                data_sample.pred_feature = feat
                data_samples.append(data_sample)

        return data_samples