# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

# pyre-unsafe

from dataclasses import dataclass
from typing import Any, Optional
import torch

from detectron2.structures import BoxMode, Instances

from .utils import AnnotationsAccumulator


@dataclass
class PackedCseAnnotations:
    x_gt: torch.Tensor
    y_gt: torch.Tensor
    coarse_segm_gt: Optional[torch.Tensor]
    vertex_mesh_ids_gt: torch.Tensor
    vertex_ids_gt: torch.Tensor
    bbox_xywh_gt: torch.Tensor
    bbox_xywh_est: torch.Tensor
    point_bbox_with_dp_indices: torch.Tensor
    point_bbox_indices: torch.Tensor
    bbox_indices: torch.Tensor


class CseAnnotationsAccumulator(AnnotationsAccumulator):
    """
    Accumulates CSE annotations in batches, where each batch corresponds to the
    objects detected in a single image, and packs them into single tensors.
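
    Example (a minimal usage sketch; `instances_per_image` is an assumed name for
    a list of `Instances`, each carrying `proposal_boxes`, `gt_boxes` and
    `gt_densepose` fields as expected by `accumulate`):

        accumulator = CseAnnotationsAccumulator()
        for instances_one_image in instances_per_image:
            accumulator.accumulate(instances_one_image)
        packed = accumulator.pack()  # None if no DensePose GT was accumulated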
    """

    def __init__(self):
        self.x_gt = []
        self.y_gt = []
        self.s_gt = []
        self.vertex_mesh_ids_gt = []
        self.vertex_ids_gt = []
        self.bbox_xywh_gt = []
        self.bbox_xywh_est = []
        self.point_bbox_with_dp_indices = []
        self.point_bbox_indices = []
        self.bbox_indices = []
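        # running indices over matched detections: `nxt_bbox_with_dp_index` counts
        # only the boxes that carry DensePose annotations, `nxt_bbox_index` counts
        # all matched boxes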
        self.nxt_bbox_with_dp_index = 0
        self.nxt_bbox_index = 0

    def accumulate(self, instances_one_image: Instances):
        """
        Accumulate instances data for one image

        Args:
            instances_one_image (Instances): instances data to accumulate
        """
        boxes_xywh_est = BoxMode.convert(
            instances_one_image.proposal_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        boxes_xywh_gt = BoxMode.convert(
            instances_one_image.gt_boxes.tensor.clone(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
        )
        n_matches = len(boxes_xywh_gt)
        assert n_matches == len(
            boxes_xywh_est
        ), f"Got {len(boxes_xywh_est)} proposal boxes and {len(boxes_xywh_gt)} GT boxes"
        if not n_matches:
            # no matches between detections and GT boxes
            return
        if (
            not hasattr(instances_one_image, "gt_densepose")
            or instances_one_image.gt_densepose is None
        ):
            # no densepose GT for the detections, just increase the bbox index
            self.nxt_bbox_index += n_matches
            return
        for box_xywh_est, box_xywh_gt, dp_gt in zip(
            boxes_xywh_est, boxes_xywh_gt, instances_one_image.gt_densepose
        ):
            if (dp_gt is not None) and (len(dp_gt.x) > 0):
                # pyre-fixme[6]: For 1st argument expected `Tensor` but got `float`.
                # pyre-fixme[6]: For 2nd argument expected `Tensor` but got `float`.
                self._do_accumulate(box_xywh_gt, box_xywh_est, dp_gt)
            self.nxt_bbox_index += 1

    def _do_accumulate(self, box_xywh_gt: torch.Tensor, box_xywh_est: torch.Tensor, dp_gt: Any):
        """
        Accumulate DensePose data for a single matched detection, given that the data is not empty

        Args:
            box_xywh_gt (tensor): GT bounding box
            box_xywh_est (tensor): estimated bounding box
            dp_gt: GT densepose data with the following attributes:
             - x: normalized X coordinates
             - y: normalized Y coordinates
             - segm (optional): tensor of size [S, S] with coarse segmentation
             - vertex_ids: tensor of mesh vertex IDs, one per annotated point
             - mesh_id: ID of the mesh the vertex IDs refer to
        """
        self.x_gt.append(dp_gt.x)
        self.y_gt.append(dp_gt.y)
        if hasattr(dp_gt, "segm"):
            self.s_gt.append(dp_gt.segm.unsqueeze(0))
        self.vertex_ids_gt.append(dp_gt.vertex_ids)
        self.vertex_mesh_ids_gt.append(torch.full_like(dp_gt.vertex_ids, dp_gt.mesh_id))
        self.bbox_xywh_gt.append(box_xywh_gt.view(-1, 4))
        self.bbox_xywh_est.append(box_xywh_est.view(-1, 4))
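        # per-point indices: `point_bbox_with_dp_indices` records the position of
        # this box among boxes that have DensePose data, `point_bbox_indices` its
        # position among all matched boxes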
        self.point_bbox_with_dp_indices.append(
            torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_with_dp_index)
        )
        self.point_bbox_indices.append(torch.full_like(dp_gt.vertex_ids, self.nxt_bbox_index))
        self.bbox_indices.append(self.nxt_bbox_index)
        self.nxt_bbox_with_dp_index += 1

    def pack(self) -> Optional[PackedCseAnnotations]:
        """
        Pack the accumulated data into a single PackedCseAnnotations instance;
        returns None if no DensePose annotations were accumulated
        """
        if not len(self.x_gt):
            # TODO:
            # returning proper empty annotations would require
            # creating empty tensors of appropriate shape and
            # type on an appropriate device;
            # we return None so far to indicate empty annotations
            return None
        return PackedCseAnnotations(
            x_gt=torch.cat(self.x_gt, 0),
            y_gt=torch.cat(self.y_gt, 0),
            vertex_mesh_ids_gt=torch.cat(self.vertex_mesh_ids_gt, 0),
            vertex_ids_gt=torch.cat(self.vertex_ids_gt, 0),
            # ignore segmentation annotations if not all instances contain them
            coarse_segm_gt=(
                torch.cat(self.s_gt, 0) if len(self.s_gt) == len(self.bbox_xywh_gt) else None
            ),
            bbox_xywh_gt=torch.cat(self.bbox_xywh_gt, 0),
            bbox_xywh_est=torch.cat(self.bbox_xywh_est, 0),
            point_bbox_with_dp_indices=torch.cat(self.point_bbox_with_dp_indices, 0),
            point_bbox_indices=torch.cat(self.point_bbox_indices, 0),
            bbox_indices=torch.as_tensor(
                self.bbox_indices, dtype=torch.long, device=self.x_gt[0].device
            ),
        )