# Copyright (c) Facebook, Inc. and its affiliates.
# -*- coding: utf-8 -*-

import typing
from typing import Any, List
import fvcore
from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
from torch import nn

from detectron2.export import TracingAdapter

__all__ = [
    "activation_count_operators",
    "flop_count_operators",
    "parameter_count_table",
    "parameter_count",
    "FlopCountAnalysis",
]

FLOPS_MODE = "flops"
ACTIVATIONS_MODE = "activations"


# Some extra ops to ignore from counting, including elementwise and reduction ops
_IGNORED_OPS = {
    "aten::add",
    "aten::add_",
    "aten::argmax",
    "aten::argsort",
    "aten::batch_norm",
    "aten::constant_pad_nd",
    "aten::div",
    "aten::div_",
    "aten::exp",
    "aten::log2",
    "aten::max_pool2d",
    "aten::meshgrid",
    "aten::mul",
    "aten::mul_",
    "aten::neg",
    "aten::nonzero_numpy",
    "aten::reciprocal",
    "aten::repeat_interleave",
    "aten::rsub",
    "aten::sigmoid",
    "aten::sigmoid_",
    "aten::softmax",
    "aten::sort",
    "aten::sqrt",
    "aten::sub",
    "torchvision::nms",  # TODO estimate flop for nms
}
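
# The ops above are skipped by default. A caller can still count one of them
# by registering a custom handle on the analysis object defined below; a
# hedged sketch (``my_nms_flop`` is a hypothetical user-provided function
# returning a Counter of flops for each traced "torchvision::nms" call):
#
#   analysis = FlopCountAnalysis(model, inputs)
#   analysis.set_op_handle("torchvision::nms", my_nms_flop)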


class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis):
    """
    Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models.
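
    Examples:
    ::
        # A minimal usage sketch; ``model`` and ``inputs`` are assumed to be
        # a detectron2 model and its standard list[dict] inputs (the names
        # are illustrative, not defined in this module):
        flops = FlopCountAnalysis(model, inputs)
        print(flops.total())        # total flop count
        print(flops.by_operator())  # flop count per operator type
        print(flops.by_module())    # flop count per submodule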
    """

    def __init__(self, model, inputs):
        """
        Args:
            model (nn.Module):
            inputs (Any): inputs of the given model. Does not have to be a tuple of tensors.
        """
        # TracingAdapter flattens detectron2's dict-based inputs/outputs into
        # tensors, so that the model can be traced by the flop counter
        wrapper = TracingAdapter(model, inputs, allow_non_tensor=True)
        super().__init__(wrapper, wrapper.flattened_inputs)
        # a None handle tells fvcore to skip the op instead of warning about it
        self.set_op_handle(**{k: None for k in _IGNORED_OPS})


def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level flop counting using jit.
    This is a wrapper of :func:`fvcore.nn.flop_count` and adds support for standard
    detection models in detectron2.
    Please use :class:`FlopCountAnalysis` for more advanced functionalities.

    Note:
        The function runs the input through the model to compute flops.
        The flops of a detection model are often input-dependent; for example,
        the flops of the box & mask heads depend on the number of proposals and
        the number of detected objects.
        Therefore, a flop count obtained from a single input may not accurately
        reflect the computation cost of the model. It's recommended to average
        across a number of inputs.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.

    Returns:
        dict[str, float]: Gflop count per operator
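
    Examples:
    ::
        # A minimal sketch; ``model`` and ``inputs`` are assumed to be a
        # detectron2 model and one batch of its standard list[dict] inputs.
        # Per the note above, averaging over several inputs is recommended:
        gflops_per_op = flop_count_operators(model, inputs)
        print(sum(gflops_per_op.values()), "total Gflops")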
    """
    old_train = model.training
    model.eval()
    ret = FlopCountAnalysis(model, inputs).by_operator()
    model.train(old_train)
    return {k: v / 1e9 for k, v in ret.items()}


def activation_count_operators(
    model: nn.Module, inputs: list, **kwargs
) -> typing.DefaultDict[str, float]:
    """
    Implement operator-level activation counting using jit.
    This is a wrapper of :func:`fvcore.nn.activation_count` that supports standard
    detection models in detectron2.

    Note:
        The function runs the input through the model to compute activations.
        The activations of a detection model are often input-dependent; for example,
        the activations of the box & mask heads depend on the number of proposals and
        the number of detected objects.

    Args:
        model: a detectron2 model that takes `list[dict]` as input.
        inputs (list[dict]): inputs to model, in detectron2's standard format.
            Only "image" key will be used.
        supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.activation_count`

    Returns:
        dict[str, float]: activation count per operator
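
    Examples:
    ::
        # A minimal sketch (``model`` and ``inputs`` are assumed, as in
        # :func:`flop_count_operators`; the batch size must be 1):
        acts_per_op = activation_count_operators(model, inputs)
        print(acts_per_op)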
    """
    return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)


def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    # ignore some ops
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    wrapper.eval()
    if mode == FLOPS_MODE:
        ret = flop_count(wrapper, (tensor_input,), **kwargs)
    elif mode == ACTIVATIONS_MODE:
        ret = activation_count(wrapper, (tensor_input,), **kwargs)
    else:
        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    # be compatible with fvcore versions that return a (count, skipped_ops) tuple
    if isinstance(ret, tuple):
        ret = ret[0]
    model.train(old_train)
    return ret


def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]:
    """
    Given a model, find parameters that do not contribute
    to the loss.

    Args:
        model: a model in training mode that returns losses
        inputs: a single argument, or a tuple of arguments, to be passed to the model

    Returns:
        list[str]: the names of unused parameters
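
    Examples:
    ::
        # A hypothetical sketch; ``model`` must be in training mode and
        # ``inputs`` must make it return losses (names are illustrative):
        unused = find_unused_parameters(model, inputs)
        if unused:
            print("parameters that received no gradient:", unused)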
    """
    assert model.training
    # clear existing gradients, so that a parameter whose ``.grad`` is still
    # None after the backward pass is known to be unused
    for _, prm in model.named_parameters():
        prm.grad = None

    if isinstance(inputs, tuple):
        losses = model(*inputs)
    else:
        losses = model(inputs)

    if isinstance(losses, dict):
        # detectron2 models return a dict of named losses; reduce it to a scalar
        losses = sum(losses.values())
    losses.backward()

    unused: List[str] = []
    for name, prm in model.named_parameters():
        # a parameter that received no gradient did not contribute to the losses
        if prm.grad is None:
            unused.append(name)
        prm.grad = None
    return unused