# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple

import torch
from mmcv.ops import point_sample
from mmengine.structures import InstanceData
from torch import Tensor

from mmseg.registry import TASK_UTILS
from mmseg.utils import ConfigType, SampleList


def seg_data_to_instance_data(ignore_index: int,
                              batch_data_samples: SampleList):
    """Turn semantic segmentation ground truth into per-class instances.

    Every class id present in an image's ``gt_sem_seg`` (except
    ``ignore_index``) becomes one "instance" whose mask marks the pixels
    carrying that class id.

    Args:
        ignore_index (int): The label index to be ignored.
        batch_data_samples (List[SegDataSample]): The Data
            Samples. Each one must provide ``gt_sem_seg.data``.

    Returns:
        List[InstanceData]: One ``InstanceData`` per data sample, holding

            - ``labels``: the class ids found in the image with
              ``ignore_index`` removed, shape (num_gt, ). They come from
              ``torch.unique(..., sorted=False)``, so no particular order
              is requested.
            - ``masks``: one integer (long) mask per label. When no label
              is left, an empty tensor of shape (0, h, w) is used, with
              h and w taken from the last two dims of ``gt_sem_seg``.
    """
    converted = []

    for sample in batch_data_samples:
        sem_seg = sample.gt_sem_seg.data
        present_ids = torch.unique(sem_seg, sorted=False)

        # drop the ignored region from the set of labels
        kept_labels = present_ids[present_ids != ignore_index]

        # one boolean mask per remaining class id
        per_class = [sem_seg == class_id for class_id in kept_labels]

        if per_class:
            instance_masks = torch.stack(per_class).squeeze(1).long()
        else:
            # nothing but ignored pixels: keep the spatial size, zero masks
            empty_shape = (0, sem_seg.shape[-2], sem_seg.shape[-1])
            instance_masks = torch.zeros(empty_shape).to(sem_seg).long()

        converted.append(
            InstanceData(labels=kept_labels, masks=instance_masks))

    return converted


class MatchMasks:
    """Match the predictions to category labels.

    For every image the predicted masks and the ground truth masks are
    sampled at the same random points, and the configured assigner decides
    which query is matched to which ground truth.

    Args:
        num_points (int): the number of sampled points to compute cost.
        num_queries (int): the number of prediction masks.
        num_classes (int): the number of classes. This value is also used
            as the label of queries that are not matched to any ground
            truth.
        assigner (ConfigType): config of the assigner that computes the
            matching; it is built through ``TASK_UTILS``. Must not be None.
    """

    def __init__(self,
                 num_points: int,
                 num_queries: int,
                 num_classes: int,
                 assigner: ConfigType = None):
        assert assigner is not None, "\'assigner\' in decode_head.train_cfg " \
                                     'cannot be None'
        assert num_points > 0, 'num_points should be a positive integer.'
        self.num_points = num_points
        self.num_queries = num_queries
        self.num_classes = num_classes
        self.assigner = TASK_UTILS.build(assigner)

    def get_targets(self, cls_scores: List[Tensor], mask_preds: List[Tensor],
                    batch_gt_instances: List[InstanceData]) -> Tuple:
        """Compute best mask matches for all images for a decoder layer.

        Args:
            cls_scores (List[Tensor] | Tensor): Mask score logits from a
                single decoder layer for all images, either as a list of
                per-image tensors or as one batched tensor. Each image has
                shape (num_queries, cls_out_channels).
            mask_preds (List[Tensor] | Tensor): Mask logits from a single
                decoder layer for all images. Each image has shape
                (num_queries, h, w).
            batch_gt_instances (List[InstanceData]): each contains
                ``labels`` and ``masks``.

        Returns:
            tuple: a tuple containing the following targets.

                - labels (Tensor): Labels of all images, stacked to shape
                  (batch_size, num_queries).
                - mask_targets (Tensor): Matched ground truth masks of all
                  images concatenated along dim 0, i.e. shape
                  (num_total_matched, h, w).
                - mask_weights (Tensor): Mask weights of all images,
                  stacked to shape (batch_size, num_queries).
                - avg_factor (int): Average factor that is used to
                  average the loss. It is the total number of ground
                  truth labels over the batch.
        """
        # ``len`` gives the batch size for a batched Tensor as well as for
        # a list of per-image tensors, whereas ``.shape`` only exists on
        # the former.
        batch_size = len(cls_scores)

        labels_list = []
        mask_targets_list = []
        mask_weights_list = []
        for i in range(batch_size):
            labels, mask_targets, mask_weights = self._get_targets_single(
                cls_scores[i], mask_preds[i], batch_gt_instances[i])
            labels_list.append(labels)
            mask_targets_list.append(mask_targets)
            mask_weights_list.append(mask_weights)

        # shape (batch_size, num_queries)
        labels = torch.stack(labels_list, dim=0)
        # shape (num_total_matched, h, w): images contribute a different
        # number of matched masks, hence ``cat`` instead of ``stack``
        mask_targets = torch.cat(mask_targets_list, dim=0)
        # shape (batch_size, num_queries)
        mask_weights = torch.stack(mask_weights_list, dim=0)

        avg_factor = sum(
            len(gt_instances.labels) for gt_instances in batch_gt_instances)

        return labels, mask_targets, mask_weights, avg_factor

    def _get_targets_single(self, cls_score: Tensor, mask_pred: Tensor,
                            gt_instances: InstanceData) \
            -> Tuple[Tensor, Tensor, Tensor]:
        """Compute a set of best mask matches for one image.

        Args:
            cls_score (Tensor): Mask score logits from a single decoder layer
                for one image. Shape (num_queries, cls_out_channels).
            mask_pred (Tensor): Mask logits for a single decoder layer for one
                image. Shape (num_queries, h, w).
            gt_instances (:obj:`InstanceData`): It contains ``labels`` and
                ``masks``.

        Returns:
            tuple[Tensor]: A tuple containing the following for one image.

                - labels (Tensor): Label of every query, shape
                  (num_queries, ). Unmatched queries get ``num_classes``.
                - mask_targets (Tensor): The ground truth masks selected by
                  the assigner, one per matched pair, shape
                  (num_matched, h, w). Empty when there is no ground truth.
                - mask_weights (Tensor): 1 for matched queries and 0 for
                  the others, shape (num_queries, ).
        """
        gt_labels = gt_instances.labels
        gt_masks = gt_instances.masks
        # when "gt_labels" is empty, classify all queries to background
        if len(gt_labels) == 0:
            labels = gt_labels.new_full((self.num_queries, ),
                                        self.num_classes,
                                        dtype=torch.long)
            # Hand back the (empty) masks instead of the labels so that the
            # target keeps the mask layout.
            # NOTE(review): assumes ``masks`` is (0, h, w) when there are no
            # labels, as produced by ``seg_data_to_instance_data``.
            mask_targets = gt_masks
            mask_weights = gt_labels.new_zeros((self.num_queries, ))
            return labels, mask_targets, mask_weights

        # sample points
        num_queries = cls_score.shape[0]
        num_gts = gt_labels.shape[0]

        # one shared set of random points, so that predictions and ground
        # truth are compared at the same locations
        point_coords = torch.rand((1, self.num_points, 2),
                                  device=cls_score.device)
        # shape (num_queries, num_points)
        mask_points_pred = point_sample(
            mask_pred.unsqueeze(1), point_coords.repeat(num_queries, 1,
                                                        1)).squeeze(1)
        # shape (num_gts, num_points)
        gt_points_masks = point_sample(
            gt_masks.unsqueeze(1).float(), point_coords.repeat(num_gts, 1,
                                                               1)).squeeze(1)

        sampled_gt_instances = InstanceData(
            labels=gt_labels, masks=gt_points_masks)
        sampled_pred_instances = InstanceData(
            scores=cls_score, masks=mask_points_pred)
        # assign and sample
        matched_query_inds, matched_label_inds = self.assigner.assign(
            pred_instances=sampled_pred_instances,
            gt_instances=sampled_gt_instances)

        # every query starts as background and matched ones are overwritten
        labels = gt_labels.new_full((self.num_queries, ),
                                    self.num_classes,
                                    dtype=torch.long)
        labels[matched_query_inds] = gt_labels[matched_label_inds]

        mask_weights = gt_labels.new_zeros((self.num_queries, ))
        mask_weights[matched_query_inds] = 1
        mask_targets = gt_masks[matched_label_inds]

        return labels, mask_targets, mask_weights