import os

from collections import OrderedDict

import torch

import transformers

import torch.nn.functional as F

from torch import nn

from torchvision.models import detection

from backbones import get_backbone

from embeddings import Box8PositionEmbedding2D

# Epsilon handed to the GroupNorm layers of the image encoders below.
EPS: float = 1e-5

# Hugging Face checkpoint used by the tokenizer and the language encoders.
# Alternative that was tried: 'distilroberta-base' (note that get_tokenizer
# always builds a BertTokenizer).
TRANSFORMER_MODEL: str = 'bert-base-uncased'


def get_tokenizer(cache=None):
    """Return a ``BertTokenizer`` for ``TRANSFORMER_MODEL``.

    Args:
        cache: Optional directory holding a local copy of the tokenizer. When
            ``None``, the tokenizer is always loaded with
            ``BertTokenizer.from_pretrained(TRANSFORMER_MODEL)``.

    Returns:
        A ``transformers.BertTokenizer`` instance.

    With ``cache`` set, the tokenizer is saved under
    ``<cache>/<TRANSFORMER_MODEL>`` on the first call. Later calls load it
    from that directory.
    """
    if cache is None:
        return transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)

    model_path = os.path.join(cache, TRANSFORMER_MODEL)

    # A tokenizer's save_pretrained writes its vocabulary ('vocab.txt' for
    # BERT) but never 'config.json', which belongs to model checkpoints.
    # Detect the cached copy by the vocabulary file, which is also what
    # BertTokenizer.from_pretrained(model_path) needs to load it.
    if os.path.isfile(os.path.join(model_path, 'vocab.txt')):
        return transformers.BertTokenizer.from_pretrained(model_path)

    tokenizer = transformers.BertTokenizer.from_pretrained(TRANSFORMER_MODEL)
    os.makedirs(model_path, exist_ok=True)
    tokenizer.save_pretrained(model_path)

    return tokenizer


def weight_init(m):
    """Initialise one module in place; intended for ``Module.apply``.

    - ``nn.Conv2d``: Xavier-normal weight with the ReLU gain, zero bias.
    - ``nn.Linear``: Xavier-normal weight (gain 1), zero bias.
    - ``nn.Embedding``: Xavier-normal weight.
    - Any other module is left untouched.
    """
    if isinstance(m, nn.Embedding):
        nn.init.xavier_normal_(m.weight)
        return

    if isinstance(m, nn.Conv2d):
        gain = nn.init.calculate_gain('relu')
    elif isinstance(m, nn.Linear):
        gain = 1.0
    else:
        return

    nn.init.xavier_normal_(m.weight, gain=gain)
    if m.bias is not None:
        nn.init.zeros_(m.bias)


class ImageEncoder(nn.Module):
    """Single-scale CNN image encoder.

    The backbone's final feature map can be concatenated with a
    ``Box8PositionEmbedding2D`` (8 extra channels). It is then projected to
    ``out_channels`` by a bias-free 1x1 convolution followed by
    ``GroupNorm(1, out_channels)``.

    Args:
        backbone: Name passed to ``get_backbone``. Accepted values:
            - any name containing ``'resnet'`` (512 feature channels for
              ``resnet18``/``resnet34``, 2048 otherwise);
            - ``'cspdarknet53'``, ``'efficientnet-b0'``, ``'efficientnet-b3'``.
        out_channels: Channels of the returned feature map.
        pretrained: Forwarded to ``get_backbone``.
        freeze_pretrained: When ``True`` together with ``pretrained``, every
            backbone parameter gets ``requires_grad = False``.
        with_pos: Concatenate the position embedding before projecting.

    Raises:
        RuntimeError: If ``backbone`` is not one of the accepted names.
    """

    def __init__(self, backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            # The last child module of the backbone provides the features.
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError('not a valid backbone')

        # The position embedding contributes 8 extra channels when enabled.
        in_channels = channels+8 if with_pos else channels

        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

        self.out_channels = out_channels

    def forward(self, img, mask=None):
        """Encode ``img`` and optionally resize ``mask`` to the feature grid.

        Args:
            img: Input batch for the backbone.
            mask: Optional 4-D mask; presumably Nx1xHxW with 1 for valid
                pixels (TODO confirm with callers). Non-floating masks
                (bool/int) are cast to float, because bilinear interpolation
                only accepts floating dtypes.

        Returns:
            ``(x, x_mask)``:
            - ``x`` is the projected NxDxHxW feature map.
            - ``x_mask`` is ``mask`` bilinearly resized to the feature map's
              HxW, thresholded at 0.5 and cast to ``long``, or ``None`` when
              no mask is given.
        """
        x = self.backbone(img)['output']
        if self.pos_emb is not None:
            x = torch.cat([x, self.pos_emb(x)], dim=1)
        x = self.proj(x)  # NxDxHxW

        x_mask = None
        if mask is not None:
            if not mask.is_floating_point():
                mask = mask.float()
            _, _, H, W = x.size()
            x_mask = F.interpolate(
                mask, (H, W), mode='bilinear', align_corners=False
            )
            x_mask = (x_mask > 0.5).long()

        return x, x_mask


class FPNImageEncoder(nn.Module):
    """ResNet + FPN image encoder that merges all pyramid levels into one map.

    Each FPN level can be concatenated with a ``Box8PositionEmbedding2D``
    (8 extra channels). It then goes through its own 1x1 convolution +
    ``GroupNorm`` and is resized (nearest) to the spatial size of the last
    level. The resized levels are summed.

    Args:
        backbone: Name passed to ``get_backbone``. Only names containing
            ``'resnet'`` are accepted.
        out_channels: FPN width and channels of the returned feature map.
        pretrained: Forwarded to ``get_backbone``.
        freeze_pretrained: When ``True`` together with ``pretrained``, every
            backbone parameter gets ``requires_grad = False``.
        with_pos: Concatenate the position embedding before projecting.

    Raises:
        RuntimeError: If ``backbone`` does not contain ``'resnet'``.
    """

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, with_pos=True):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            # Channel counts of layer1..layer4, which feed the FPN.
            if backbone in ('resnet18', 'resnet34'):
                in_channels_list = [64, 128, 256, 512]
            else:
                in_channels_list = [256, 512, 1024, 2048]
            return_layers = OrderedDict({
                'layer1': '0', 'layer2': '1', 'layer3': '2', 'layer4': '3'
            })

        # elif backbone == 'cspdarknet53':
        #     in_channels_list = [128, 256, 512, 1024]
        #     return_layers = OrderedDict({
        #         '1':'0', '2':'1', '3':'2', '4':'3'
        #     })

        else:
            raise RuntimeError('not a valid backbone')

        self.backbone = model

        self.fpn = detection.backbone_utils.BackboneWithFPN(
            backbone=self.backbone,
            return_layers=return_layers,
            in_channels_list=in_channels_list,
            out_channels=out_channels
        )

        self.fpn.fpn.extra_blocks = None   # removes the 'pool' layer added by default

        self.out_channels = out_channels

        # FPN levels have out_channels channels; the position embedding
        # contributes 8 more when enabled.
        in_channels = out_channels + 8 if with_pos else out_channels

        self.proj = nn.ModuleDict({
            level: nn.Sequential(
                nn.Conv2d(in_channels, out_channels, (1, 1), 1, bias=False),
                nn.GroupNorm(1, out_channels, eps=EPS),
                # nn.ReLU(inplace=True),
            ) for level in return_layers.values()
        })
        self.proj.apply(weight_init)

        self.pos_emb = None
        if with_pos:
            self.pos_emb = Box8PositionEmbedding2D(with_projection=False)

    def forward(self, x, mask=None):
        """Encode ``x`` into a single feature map.

        Args:
            x: Input image batch.
            mask: Optional 4-D mask; presumably Nx1xHxW with 1 for valid
                pixels (TODO confirm with callers). Non-floating masks are
                cast to float, because bilinear interpolation only accepts
                floating dtypes.

        Returns:
            ``(x_out, x_mask)``:
            - ``x_out`` is the sum of all projected levels, resized to the
              last level's HxW.
            - ``x_mask`` is ``mask`` resized to that HxW, thresholded at 0.5
              and cast to ``long``, or ``None`` when no mask is given.
        """
        x = self.fpn(x)

        # smallest feature map (eg. 16x16 for an input of 512x512 pixels)
        _, _, H, W = list(x.values())[-1].size()

        x_out = None
        for level, fmap in x.items():
            # fmap = torch.relu(fmap)  # FPN blocks end in a conv2d, w/o activ.
            if self.pos_emb is not None:
                fmap = torch.cat([fmap, self.pos_emb(fmap)], dim=1)  # +Pos
            fmap = self.proj[level](fmap)  # 1x1 Conv + GroupNorm (no activation)
            fmap = F.interpolate(fmap, (H, W), mode='nearest')  # to the last level's size
            # Accumulate out of place so no tensor produced by the graph is mutated.
            x_out = fmap if x_out is None else x_out + fmap

        x_mask = None
        if mask is not None:
            if not mask.is_floating_point():
                mask = mask.float()
            x_mask = F.interpolate(
                mask, (H, W), mode='bilinear', align_corners=False
            )
            x_mask = (x_mask > 0.5).long()

        return x_out, x_mask


class TransformerImageEncoder(nn.Module):
    """CNN backbone followed by a transformer encoder over feature-map cells.

    The backbone's final map is projected to ``out_channels`` (1x1 conv +
    ``GroupNorm``) and flattened to a sequence of H*W tokens. A
    ``Box8PositionEmbedding2D`` of width ``out_channels`` is passed to the
    encoder as ``pos``. After encoding, the tokens are reshaped back to
    NxDxHxW.

    Args:
        backbone: Name passed to ``get_backbone``. Accepted values:
            - any name containing ``'resnet'``;
            - ``'cspdarknet53'``, ``'efficientnet-b0'``, ``'efficientnet-b3'``.
        out_channels: Model width ``d_model`` and channels of the output map.
        pretrained: Forwarded to ``get_backbone``.
        freeze_pretrained: When ``True`` together with ``pretrained``, every
            backbone parameter gets ``requires_grad = False``.
        num_heads: ``nhead`` of each ``TransformerEncoderLayer``.
        num_layers: Number of encoder layers.
        dropout_p: Dropout of each encoder layer.

    Raises:
        RuntimeError: If ``backbone`` is not one of the accepted names.
    """

    def __init__(self,
                 backbone='resnet50', out_channels=256, pretrained=True,
                 freeze_pretrained=False, num_heads=8, num_layers=6,
                 dropout_p=0.1):
        super().__init__()

        model = get_backbone(backbone, pretrained)

        if pretrained and freeze_pretrained:
            for p in model.parameters():
                p.requires_grad = False

        if 'resnet' in backbone:
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({'layer4': 'output'})
            )
            channels = 512 if backbone in ('resnet18', 'resnet34') else 2048

        elif backbone in ('cspdarknet53', 'efficientnet-b0', 'efficientnet-b3'):
            # The last child module of the backbone provides the features.
            output_layer_name = list(model.named_children())[-1][0]
            self.backbone = detection.backbone_utils.IntermediateLayerGetter(
                model, return_layers=OrderedDict({output_layer_name: 'output'})
            )
            channels = {
                'cspdarknet53': 1024,
                'efficientnet-b0': 1280,
                'efficientnet-b3': 1536
            }[backbone]

        else:
            raise RuntimeError('not a valid backbone')

        self.proj = nn.Sequential(
            nn.Conv2d(channels, out_channels, (1, 1), 1, bias=False),
            nn.GroupNorm(1, out_channels, eps=EPS),
            # nn.ReLU(inplace=True),
        )
        self.proj.apply(weight_init)

        from transformers_pos import (
            TransformerEncoder,
            TransformerEncoderLayer,
        )

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(
                d_model=out_channels,
                nhead=num_heads,
                dropout=dropout_p,
                batch_first=True
            ),
            num_layers=num_layers
        )

        self.pos_emb = Box8PositionEmbedding2D(embedding_dim=out_channels)

        self.out_channels = out_channels

    def flatten(self, x):
        """Turn an NxCxHxW tensor into an Nx(H*W)xC sequence (row-major cells)."""
        N, _, H, W = x.size()
        # In channels_last layout the NHWC permute below is contiguous, so
        # .view can be used.
        x = x.to(memory_format=torch.channels_last)
        x = x.permute(0, 2, 3, 1).view(N, H*W, -1)  # NxHWxD
        return x

    def forward(self, img, mask=None):
        """Encode ``img``; optionally use ``mask`` as the encoder padding mask.

        Args:
            img: Input batch for the backbone.
            mask: Optional 4-D mask; presumably Nx1xHxW with 1 for valid
                pixels (TODO confirm with callers). Non-floating masks are
                cast to float, because bilinear interpolation only accepts
                floating dtypes.

        Returns:
            ``(x, x_mask)``:
            - ``x`` is the encoded NxDxHxW feature map.
            - ``x_mask`` is ``mask`` resized to HxW, thresholded at 0.5 and
              cast to ``long``, or ``None`` when no mask is given.
        """
        x = self.backbone(img)['output']
        x = self.proj(x)  # NxDxHxW

        N, _, H, W = x.size()

        pos = self.pos_emb(x)  # NxDxHxW
        pos = self.flatten(pos)  # NxRxD, R = H*W

        x = self.flatten(x)  # NxRxD

        # visibility mask
        x_mask = None
        if mask is not None:
            if not mask.is_floating_point():
                mask = mask.float()
            x_mask = F.interpolate(
                mask, (H, W), mode='bilinear', align_corners=False
            )
            x_mask = (x_mask > 0.5).long()

        if x_mask is None:
            x = self.encoder(x, pos=pos)  # NxRxD
        else:
            # True marks cells whose mask is 0. This follows the PyTorch
            # src_key_padding_mask convention (True = ignore); assumed to hold
            # for transformers_pos too.
            # NOTE(review): a sample with no visible cell gives an all-True row;
            # confirm the encoder handles that case.
            key_padding_mask = self.flatten(x_mask).squeeze(-1) == 0  # NxR
            x = self.encoder(x, src_key_padding_mask=key_padding_mask, pos=pos)  # NxRxD

        x = x.permute(0, 2, 1).view(N, -1, H, W)  # NxDxHxW

        return x, x_mask


class LanguageEncoder(nn.Module):
    """Pretrained ``TRANSFORMER_MODEL`` text encoder plus a linear projection.

    Args:
        out_features: Size of the projected embeddings.
        dropout_p: Currently unused (the dropout in ``proj`` is commented
            out); kept for interface compatibility.
        freeze_pretrained: When ``True``, every transformer parameter gets
            ``requires_grad = False``.
        global_pooling: If ``True``, project ``pooler_output`` (one vector per
            sentence). Otherwise project ``last_hidden_state`` (one vector per
            token) and also return the attention mask.
    """

    def __init__(self, out_features=256, dropout_p=0.2,
                 freeze_pretrained=False, global_pooling=True):
        super().__init__()
        self.language_model = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        )

        if freeze_pretrained:
            for p in self.language_model.parameters():
                p.requires_grad = False

        self.out_features = out_features

        # Read the transformer width from its config instead of hard-coding
        # 768, so TRANSFORMER_MODEL can point at a model of another size.
        hidden_size = self.language_model.config.hidden_size

        self.proj = nn.Sequential(
            nn.Linear(hidden_size, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.global_pooling = bool(global_pooling)

    def forward(self, z):
        """Encode a tokenized batch.

        Args:
            z: Mapping with ``'input_ids'`` and ``'attention_mask'``
                (presumably the output of ``get_tokenizer()``; verify against
                callers).

        Returns:
            ``(z, z_mask)``:
            - with global pooling, the projected ``pooler_output`` and ``None``;
            - otherwise, the projected ``last_hidden_state`` and
              ``z['attention_mask']``.
        """
        res = self.language_model(
            input_ids=z['input_ids'],
            position_ids=None,
            attention_mask=z['attention_mask']
        )

        if self.global_pooling:
            z, z_mask = self.proj(res.pooler_output), None
        else:
            z, z_mask = self.proj(res.last_hidden_state), z['attention_mask']

        return z, z_mask


class RNNLanguageEncoder(nn.Module):
    """Bidirectional GRU/LSTM sentence encoder over pretrained word embeddings.

    The word-embedding table of ``TRANSFORMER_MODEL`` is reused and kept
    trainable; the rest of the transformer is not kept. The final hidden
    states of the last RNN layer (both directions) are concatenated and
    projected to ``out_features``.

    Args:
        model_type: ``'gru'`` or ``'lstm'``.
        hidden_size: Hidden size per direction.
        num_layers: Number of stacked RNN layers.
        out_features: Size of the projected sentence embedding.
        dropout_p: Dropout on the embeddings and between RNN layers.
        global_pooling: Must be truthy; only sentence-level output exists.

    Raises:
        AssertionError: If ``model_type`` is not ``'gru'``/``'lstm'`` or
            ``global_pooling`` is falsy.
    """

    def __init__(self,
                 model_type='gru', hidden_size=1024, num_layers=2,
                 out_features=256, dropout_p=0.2, global_pooling=True):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        assert model_type in ('gru', 'lstm')
        self.rnn = (nn.GRU if model_type == 'gru' else nn.LSTM)(
            input_size=self.embeddings.weight.size(1),
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout_p,
            batch_first=True,
            bidirectional=True
        )

        self.proj = nn.Sequential(
            nn.Linear(2*hidden_size, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert global_pooling  # only w/ global pooling

    def forward(self, z):
        """Encode a tokenized batch into one vector per sentence.

        Args:
            z: Mapping with ``'input_ids'`` (NxL) and ``'attention_mask'``
                (NxL). Assumes right padding: each row of the mask is 1 for a
                prefix of real tokens, then 0 (TODO confirm the tokenizer's
                padding side).

        Returns:
            ``(h, z_mask)``:
            - ``h`` is the projected N x out_features sentence embedding;
            - ``z_mask`` is ``z['attention_mask']``, unchanged.
        """
        z_mask = z['attention_mask']
        input_ids = z['input_ids']
        batch_size = input_ids.size(0)

        emb = self.dropout_emb(self.embeddings(input_ids))

        # Pack the batch so each sequence stops at its last real token.
        # Without packing, the forward direction's final state has also
        # consumed the trailing padding, and the backward direction starts
        # from it, so the output depended on how much padding the batch had.
        # Lengths must be on the CPU and >= 1 for pack_padded_sequence.
        lengths = z_mask.sum(dim=1).clamp(min=1).to('cpu')
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths, batch_first=True, enforce_sorted=False
        )
        _, h_n = self.rnn(packed, None)

        if isinstance(self.rnn, nn.LSTM):
            h_n = h_n[0]  # (h_n, c_n) -> h_n

        # hidden states as (num_layers, num_directions, batch, hidden_size)
        h_n = h_n.view(self.rnn.num_layers, 2, batch_size, self.rnn.hidden_size)

        # last layer, both directions -> (batch, 2 * hidden_size)
        h_n = h_n[-1].permute(1, 0, 2).reshape(batch_size, -1)
        h_n = self.proj(h_n)
        return h_n, z_mask


class SimpleEncoder(nn.Module):
    """Per-token text encoder: pretrained word embeddings + linear projection.

    The word-embedding table of ``TRANSFORMER_MODEL`` is reused and kept
    trainable; the rest of the transformer is not kept. No contextualisation
    is applied.

    Args:
        out_features: Size of the projected token embeddings.
        dropout_p: Dropout applied to the embeddings before projection.
        global_pooling: Must be falsy; only per-token output exists. The
            default is ``False`` (the former default ``True`` always tripped
            the assertion below).

    Raises:
        AssertionError: If ``global_pooling`` is truthy.
    """

    def __init__(self, out_features=256, dropout_p=0.1, global_pooling=False):
        super().__init__()
        self.embeddings = transformers.AutoModel.from_pretrained(
            TRANSFORMER_MODEL
        ).embeddings.word_embeddings
        self.embeddings.weight.requires_grad = True

        # self.dropout_emb = nn.Dropout(0.5)
        self.dropout_emb = nn.Dropout(dropout_p)

        # Take the input width from the embedding table instead of
        # hard-coding 768, matching how RNNLanguageEncoder sizes its RNN.
        self.proj = nn.Sequential(
            nn.Linear(self.embeddings.embedding_dim, out_features),
            nn.LayerNorm(out_features, eps=1e-5),
            # nn.ReLU(inplace=True),
            # nn.Dropout(dropout_p),
        )
        self.proj.apply(weight_init)

        self.out_features = out_features

        self.global_pooling = bool(global_pooling)
        assert not self.global_pooling  # only w/o global pooling

    def forward(self, z):
        """Embed and project each token.

        Args:
            z: Mapping with ``'input_ids'`` and ``'attention_mask'``.

        Returns:
            ``(tokens, z_mask)``:
            - ``tokens`` is the projected embedding of every input id;
            - ``z_mask`` is ``z['attention_mask']``, unchanged.
        """
        z_mask = z['attention_mask']
        z = self.embeddings(z['input_ids'])
        z = self.proj(self.dropout_emb(z))
        # z[:, 0] = torch.mean(z[:, 1:], 1)
        return z, z_mask