import logging

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision.models.feature_extraction import create_feature_extractor

from .base import BaseModel
from .schema import ResNetConfiguration

logger = logging.getLogger(__name__)


class DecoderBlock(nn.Module):
    """Fuse an encoder skip map into a coarser decoder map.

    The skip map is projected by ``num_convs`` conv(-norm)-ReLU units and added
    to the decoder map after bilinearly upsampling the latter to the skip's
    resolution.

    Args:
        previous: number of channels of the *skip* map given to ``forward``
            (input channels of the first convolution).
        out: number of channels produced by the projection; the decoder map
            given to ``forward`` must have this many channels for the sum.
        ksize: convolution kernel size, padded with ``ksize // 2``.
        num_convs: number of conv(-norm)-ReLU units in the projection.
        norm: normalization layer constructor, or ``None`` to disable it (the
            convolutions then carry a bias term instead).
        padding: ``padding_mode`` forwarded to every ``nn.Conv2d``.
    """

    def __init__(
        self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
    ):
        super().__init__()
        has_norm = norm is not None
        modules = []
        in_ch = previous
        for _ in range(num_convs):
            modules.append(
                nn.Conv2d(
                    in_ch,
                    out,
                    kernel_size=ksize,
                    padding=ksize // 2,
                    bias=not has_norm,
                    padding_mode=padding,
                )
            )
            if has_norm:
                modules.append(norm(out))
            modules.append(nn.ReLU(inplace=True))
            in_ch = out
        # A single flat Sequential keeps parameter names (layers.<i>.*) stable.
        self.layers = nn.Sequential(*modules)

    def forward(self, previous, skip):
        """Upsample ``previous`` to ``skip``'s scale and add the projected skip.

        Args:
            previous: coarse decoder map, unpacked as (B, C, H, W).
            skip: finer encoder map, unpacked as (B, C, H, W).

        Returns:
            The projected (cropped or padded) skip map plus the upsampled map.

        Raises:
            ValueError: if the skip map is larger than the upsampled map along
                one axis but smaller along the other.
        """
        _, _, coarse_h, coarse_w = previous.shape
        _, _, skip_h, skip_w = skip.shape
        # Snap each axis' resolution ratio to the nearest power of two.
        ratios = np.array([skip_h / coarse_h, skip_w / coarse_w])
        factors = 2 ** np.round(np.log2(ratios))
        upsampled = nn.functional.interpolate(
            previous, scale_factor=factors.tolist(), mode="bilinear", align_corners=False
        )

        # Odd-sized skip maps do not exactly match the power-of-two upsampling.
        # A larger skip (e.g. floor-mode downsampling) is cropped, and a
        # smaller one is zero-padded on the bottom/right.
        _, _, up_h, up_w = upsampled.shape
        _, _, skip_h, skip_w = skip.shape
        fits_inside = up_h <= skip_h and up_w <= skip_w
        covers = up_h >= skip_h and up_w >= skip_w
        if fits_inside:
            skip = skip[:, :, :up_h, :up_w]
        elif covers:
            skip = nn.functional.pad(skip, [0, up_w - skip_w, 0, up_h - skip_h])
        else:
            raise ValueError(
                f"Inconsistent skip vs upsampled shapes: {(skip_h, skip_w)}, {(up_h, up_w)}"
            )

        projected = self.layers(skip)
        return projected + upsampled


class FPN(nn.Module):
    """Top-down feature-pyramid decoder.

    The deepest encoder map is projected to ``out_channels`` by a 1x1
    convolution. It is then refined coarse-to-fine by one ``DecoderBlock`` per
    shallower map and finally passed through a 3x3 conv-BatchNorm-ReLU head.

    Args:
        in_channels_list: channel counts of the encoder maps, ordered from
            shallowest to deepest (the order of ``forward``'s dict values).
        out_channels: number of channels of the decoded map.
        **kw: extra keyword arguments forwarded to every ``DecoderBlock``.
    """

    def __init__(self, in_channels_list, out_channels, **kw):
        super().__init__()
        deepest_channels = in_channels_list[-1]
        self.first = nn.Conv2d(
            deepest_channels, out_channels, 1, padding=0, bias=True
        )
        # One block per shallower map, ordered from deepest to shallowest.
        shallower_channels = in_channels_list[:-1]
        self.blocks = nn.ModuleList(
            DecoderBlock(channels, out_channels, ksize=1, **kw)
            for channels in reversed(shallower_channels)
        )
        # NOTE(review): the head always uses BatchNorm2d, independently of any
        # `norm` passed through **kw to the DecoderBlocks.
        self.out = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, layers):
        """Decode a mapping of encoder maps (shallowest first) into one map."""
        maps = list(layers.values())
        feats = self.first(maps[-1])
        for i, skip in enumerate(reversed(maps[:-1])):
            feats = self.blocks[i](feats, skip)
        return self.out(feats)


def remove_conv_stride(conv):
    """Return a copy of ``conv`` whose stride is 1.

    The returned layer reuses the very same weight and bias ``nn.Parameter``
    objects as ``conv``, so pretrained values are kept and stay tied. Every
    other hyper-parameter (kernel size, padding, dilation, groups, padding
    mode) is carried over so that only the stride changes. Without ``groups``
    the shared weight would not match the new layer's expected shape.

    Args:
        conv: the ``nn.Conv2d`` to convert.

    Returns:
        A new ``nn.Conv2d`` with ``stride=1`` sharing ``conv``'s parameters.
    """
    conv_new = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        stride=1,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )
    # Share (not copy) the parameters of the original convolution.
    conv_new.weight = conv.weight
    conv_new.bias = conv.bias
    return conv_new


class FeatureExtractor(BaseModel):
    """Torchvision encoder followed by an FPN decoder.

    The encoder returns the intermediate maps selected in ``build_encoder``.
    The ``FPN`` decoder fuses them into a single map with ``output_dim``
    channels.
    """

    default_conf = {
        "pretrained": True,
        "input_dim": 3,
        "output_dim": 128,  # # of channels in output feature maps
        "encoder": "resnet50",  # torchvision net as string
        "remove_stride_from_first_conv": False,
        "num_downsample": None,  # keep only the first N encoder layers (None = all)
        "decoder_norm": "nn.BatchNorm2d",  # normalization in decoder blocks
        "do_average_pooling": False,
        "checkpointed": False,  # whether to use gradient checkpointing
    }
    # Per-channel input normalization applied in `_forward` (the commonly
    # used ImageNet mean/std).
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    def freeze_encoder(self):
        """
        Freeze the encoder part of the model, i.e., set requires_grad = False
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = False
        logger.debug("Encoder has been frozen.")

    def unfreeze_encoder(self):
        """
        Unfreeze the encoder part of the model, i.e., set requires_grad = True
        for all parameters in the encoder.
        """
        for param in self.encoder.parameters():
            param.requires_grad = True
        logger.debug("Encoder has been unfrozen.")

    def build_encoder(self, conf: ResNetConfiguration):
        """Build the torchvision backbone wrapped as a multi-output extractor.

        Args:
            conf: configuration providing ``encoder`` (torchvision model name),
                ``pretrained``, ``input_dim``, ``num_downsample``,
                ``remove_stride_from_first_conv``, ``do_average_pooling`` and
                ``checkpointed``.

        Returns:
            ``(encoder, layers)``: the feature-extractor module returned by
            ``create_feature_extractor`` and the list of node names it outputs.

        Raises:
            NotImplementedError: for encoders other than ``resnet*``/``vgg13``/
                ``vgg16``, or if ``do_average_pooling``/``checkpointed`` is set.
        """
        assert isinstance(conf.encoder, str)
        if conf.pretrained:
            # Pretrained torchvision weights expect 3-channel input.
            assert conf.input_dim == 3
        Encoder = getattr(torchvision.models, conf.encoder)

        kw = {}
        if conf.encoder.startswith("resnet"):
            layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
            kw["replace_stride_with_dilation"] = [False, False, False]
        elif conf.encoder == "vgg13":
            layers = [
                "features.3",
                "features.8",
                "features.13",
                "features.18",
                "features.23",
            ]
        elif conf.encoder == "vgg16":
            layers = [
                "features.3",
                "features.8",
                "features.15",
                "features.22",
                "features.29",
            ]
        else:
            raise NotImplementedError(conf.encoder)

        if conf.num_downsample is not None:
            layers = layers[: conf.num_downsample]
        encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
        encoder = create_feature_extractor(encoder, return_nodes=layers)
        if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
            encoder.conv1 = remove_conv_stride(encoder.conv1)

        if conf.do_average_pooling:
            raise NotImplementedError
        if conf.checkpointed:
            raise NotImplementedError

        return encoder, layers

    def _init(self, conf):
        """Create the normalization buffers, the encoder and the FPN decoder."""
        # Preprocessing (non-persistent buffers: not stored in the state dict)
        self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
        self.register_buffer("std_", torch.tensor(self.std), persistent=False)

        # Encoder
        self.encoder, self.layers = self.build_encoder(conf)

        # Probe the encoder with a dummy image to infer the channel count and
        # stride of each returned layer. The probe runs in eval mode because,
        # in training mode, BatchNorm layers would fold the statistics of this
        # all-zero input into their running mean/var and corrupt pretrained
        # weights. no_grad() also avoids building an autograd graph.
        s = 128
        inp = torch.zeros(1, 3, s, s)
        was_training = self.encoder.training
        self.encoder.eval()
        try:
            with torch.no_grad():
                features = list(self.encoder(inp).values())
        finally:
            self.encoder.train(was_training)
        self.skip_dims = [x.shape[1] for x in features]
        # Stride = probe width divided by the map width.
        self.layer_strides = [s / f.shape[-1] for f in features]
        self.scales = [self.layer_strides[0]]

        # Decoder
        # NOTE(review): eval() executes the config string (e.g. "nn.BatchNorm2d");
        # only safe for trusted configurations.
        norm = eval(conf.decoder_norm) if conf.decoder_norm else None  # noqa
        self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)

        logger.debug(
            "Built feature extractor with layers {name:dim:stride}:\n"
            f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
            f"and output scales {self.scales}."
        )

    def _forward(self, data):
        """Normalize ``data["image"]``, encode it and decode the skip features.

        Args:
            data: mapping with at least ``"image"`` and ``"camera"``. The image
                is presumably (B, 3, H, W) RGB scaled like torchvision inputs;
                TODO confirm against callers.

        Returns:
            ``(features, data["camera"])``. The decoder output is returned
            together with the camera entry, which is passed through unchanged.
        """
        image = data["image"]
        # Broadcast the per-channel mean/std over the spatial dimensions.
        image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]

        skip_features = self.encoder(image)
        output = self.decoder(skip_features)
        return output, data['camera']