# Copyright (c) OpenMMLab. All rights reserved.
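"""Tests for ConvModule: building conv/norm/act layers from configs, bias
handling, spectral norm, padding modes and layer ordering."""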
import warnings
from unittest.mock import patch

import pytest
import torch
import torch.nn as nn

from mmcv.cnn.bricks import CONV_LAYERS, ConvModule, HSigmoid, HSwish
from mmcv.utils import TORCH_VERSION, digit_version


@CONV_LAYERS.register_module()
class ExampleConv(nn.Module):
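    """A minimal custom conv layer for testing the ``conv_cfg`` plugin path.

    Registering it in CONV_LAYERS lets ConvModule build it via
    ``conv_cfg=dict(type='ExampleConv')``. It defines its own
    ``init_weights`` (all-zero weights) so the tests can check that
    ConvModule keeps a conv layer's custom initialization instead of
    applying its default one.
    """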

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 norm_cfg=None):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.norm_cfg = norm_cfg
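        # ConvModule re-exports `output_padding` and `transposed` from the
        # wrapped conv layer, so a custom conv layer must define them too.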
        self.output_padding = (0, 0)
        self.transposed = False

        self.conv0 = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.init_weights()

    def forward(self, x):
        x = self.conv0(x)
        return x

    def init_weights(self):
        nn.init.constant_(self.conv0.weight, 0)


def test_conv_module():
    with pytest.raises(AssertionError):
        # conv_cfg must be a dict or None
        conv_cfg = 'conv'
        ConvModule(3, 8, 2, conv_cfg=conv_cfg)

    with pytest.raises(AssertionError):
        # norm_cfg must be a dict or None
        norm_cfg = 'norm'
        ConvModule(3, 8, 2, norm_cfg=norm_cfg)

    with pytest.raises(KeyError):
        # softmax is not supported
        act_cfg = dict(type='softmax')
        ConvModule(3, 8, 2, act_cfg=act_cfg)

    # conv + norm + act
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert conv.with_norm
    assert hasattr(conv, 'norm')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv + act
    conv = ConvModule(3, 8, 2)
    assert conv.with_activation
    assert hasattr(conv, 'activate')
    assert not conv.with_norm
    assert conv.norm is None
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv
    conv = ConvModule(3, 8, 2, act_cfg=None)
    assert not conv.with_norm
    assert conv.norm is None
    assert not conv.with_activation
    assert not hasattr(conv, 'activate')
    x = torch.rand(1, 3, 256, 256)
    output = conv(x)
    assert output.shape == (1, 8, 255, 255)

    # conv with its own `init_weights` method
    conv_module = ConvModule(
        3, 8, 2, conv_cfg=dict(type='ExampleConv'), act_cfg=None)
    assert torch.equal(conv_module.conv.conv0.weight, torch.zeros(8, 3, 2, 2))

    # with_spectral_norm=True
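    # torch's spectral_norm reparametrizes the weight, keeping the raw
    # parameter as `weight_orig`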
    conv = ConvModule(3, 8, 3, padding=1, with_spectral_norm=True)
    assert hasattr(conv.conv, 'weight_orig')
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # padding_mode='reflect'
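    # padding modes other than 'zeros'/'circular' are applied through an
    # explicit padding layer inserted before the conv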
    conv = ConvModule(3, 8, 3, padding=1, padding_mode='reflect')
    assert isinstance(conv.padding_layer, nn.ReflectionPad2d)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # a non-existent padding mode raises a KeyError
    with pytest.raises(KeyError):
        conv = ConvModule(3, 8, 3, padding=1, padding_mode='non_exists')

    # LeakyReLU
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='LeakyReLU'))
    assert isinstance(conv.activate, nn.LeakyReLU)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # Tanh
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='Tanh'))
    assert isinstance(conv.activate, nn.Tanh)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # Sigmoid
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='Sigmoid'))
    assert isinstance(conv.activate, nn.Sigmoid)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # PReLU
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='PReLU'))
    assert isinstance(conv.activate, nn.PReLU)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # HSwish
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSwish'))
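    # torch >= 1.7 ships a native nn.Hardswish, which mmcv registers as
    # 'HSwish'; older torch and parrots fall back to mmcv's own HSwish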
    if (TORCH_VERSION == 'parrots'
            or digit_version(TORCH_VERSION) < digit_version('1.7')):
        assert isinstance(conv.activate, HSwish)
    else:
        assert isinstance(conv.activate, nn.Hardswish)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)

    # HSigmoid
    conv = ConvModule(3, 8, 3, padding=1, act_cfg=dict(type='HSigmoid'))
    assert isinstance(conv.activate, HSigmoid)
    output = conv(x)
    assert output.shape == (1, 8, 256, 256)


def test_bias():
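    # With the default bias='auto', ConvModule adds a conv bias only when
    # no norm layer is configured.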
    # bias: auto, without norm
    conv = ConvModule(3, 8, 2)
    assert conv.conv.bias is not None

    # bias: auto, with norm
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    assert conv.conv.bias is None

    # bias: False, without norm
    conv = ConvModule(3, 8, 2, bias=False)
    assert conv.conv.bias is None

    # bias: True, with batch norm
    with pytest.warns(UserWarning) as record:
        ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='BN'))
    assert len(record) == 1
    assert record[0].message.args[
        0] == 'Unnecessary conv bias before batch/instance norm'

    # bias: True, with instance norm
    with pytest.warns(UserWarning) as record:
        ConvModule(3, 8, 2, bias=True, norm_cfg=dict(type='IN'))
    assert len(record) == 1
    assert record[0].message.args[
        0] == 'Unnecessary conv bias before batch/instance norm'

    # bias: True, with other norm
    with pytest.warns(UserWarning) as record:
        norm_cfg = dict(type='GN', num_groups=1)
        ConvModule(3, 8, 2, bias=True, norm_cfg=norm_cfg)
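        # pytest.warns fails when nothing is raised, so emit a dummy
        # warning; asserting it is the only record proves ConvModule
        # itself did not warn for GN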
        warnings.warn('No warnings')
    assert len(record) == 1
    assert record[0].message.args[0] == 'No warnings'


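# Stub forwards that tag a plain string instead of computing tensors, so
# test_order can trace the exact conv/norm/act execution order.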
def conv_forward(self, x):
    return x + '_conv'


def bn_forward(self, x):
    return x + '_bn'


def relu_forward(self, x):
    return x + '_relu'


@patch('torch.nn.ReLU.forward', relu_forward)
@patch('torch.nn.BatchNorm2d.forward', bn_forward)
@patch('torch.nn.Conv2d.forward', conv_forward)
def test_order():

    with pytest.raises(AssertionError):
        # order must be a tuple
        order = ['conv', 'norm', 'act']
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # length of order must be 3
        order = ('conv', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be a permutation of ('conv', 'norm', 'act')
        order = ('conv', 'norm', 'norm')
        ConvModule(3, 8, 2, order=order)

    with pytest.raises(AssertionError):
        # order must be a permutation of ('conv', 'norm', 'act')
        order = ('conv', 'norm', 'something')
        ConvModule(3, 8, 2, order=order)

    # ('conv', 'norm', 'act')
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input')
    assert out == 'input_conv_bn_relu'

    # ('norm', 'conv', 'act')
    conv = ConvModule(
        3, 8, 2, norm_cfg=dict(type='BN'), order=('norm', 'conv', 'act'))
    out = conv('input')
    assert out == 'input_bn_conv_relu'

    # ('conv', 'norm', 'act'), activate=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', activate=False)
    assert out == 'input_conv_bn'

    # ('conv', 'norm', 'act'), norm=False
    conv = ConvModule(3, 8, 2, norm_cfg=dict(type='BN'))
    out = conv('input', norm=False)
    assert out == 'input_conv_relu'