# Copyright (c) OpenMMLab. All rights reserved.
import json
import math
import torch
import torch.nn as nn
from mmengine.model.weight_init import (constant_init, kaiming_init,
                                        trunc_normal_)
from mmengine.model  import ModuleList
from mmengine.runner.checkpoint import _load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES
from .mae import MAE
from mmengine.model  import BaseModule
import numpy as np
from .lora import wrap_model_with_lora, Linear

def rearrange_activations(activations):
    """Flatten all leading axes so the result is 2-D: ``(-1, channels)``.

    The size of the last axis is preserved; every other axis is merged
    into the first dimension.
    """
    return activations.reshape(-1, activations.shape[-1])

def ps_inv(x1, x2):
    """Fit an affine map from ``x1`` features to ``x2`` features.

    Both inputs are flattened to 2-D with ``rearrange_activations``. The
    source features are extended with a column of ones so that the
    pseudo-inverse solution carries a bias term alongside the weights.

    Returns:
        tuple: ``(w, b)`` where ``w`` holds every column but the last of the
        transposed solution and ``b`` holds its last column.

    Raises:
        ValueError: if the flattened inputs have a different number of rows.
    """
    src = rearrange_activations(x1)
    dst = rearrange_activations(x2)

    if src.shape[0] != dst.shape[0]:
        raise ValueError('Spatial size of compared neurons must match when '
                         'calculating psuedo inverse matrix.')

    # Source features plus one trailing column of ones (the bias input).
    # The buffer is created with torch's defaults, exactly as torch.ones does.
    n_rows, n_feats = src.shape
    augmented = torch.ones([n_rows, n_feats + 1])
    augmented[:, :-1] = src

    # Least-squares solution, transposed so rows index output features.
    target = dst.to(augmented.device)
    affine = torch.matmul(torch.linalg.pinv(augmented), target).T

    # Split the solution into weight matrix and bias vector.
    return affine[..., :-1], affine[..., -1]

def reset_out_indices(front_depth=12, end_depth=24, out_indices=(9, 14, 19, 23)):
    """Translate output indices of the deeper model to the shallower one.

    The block ids ``0 .. front_depth - 1`` are stretched to length
    ``end_depth`` with ``torch.nn.functional.interpolate`` (default mode), so
    position ``i`` of the stretched list gives the shallow block that
    corresponds to deep block ``i``.

    Returns:
        list[int]: the mapped ids for every position found in
        ``out_indices``, in ascending position order.
    """
    shallow_ids = torch.tensor(list(range(front_depth))).float()
    stretched = torch.nn.functional.interpolate(
        shallow_ids[None, None, :], end_depth)
    deep_to_shallow = stretched.squeeze().long().tolist()

    return [
        shallow_idx
        for deep_idx, shallow_idx in enumerate(deep_to_shallow)
        if deep_idx in out_indices
    ]


def get_stitch_configs_general_unequal(depths):
    """Build the one-directional (small -> large) stitching configurations.

    Args:
        depths: block counts of the anchors; they are sorted ascending before
            use, so index 0 is the shallowest.

    Returns:
        tuple: ``(total_configs, num_stitches)``. ``total_configs`` starts
        with the two single-anchor entries, followed by one entry per small
        block except the last one. ``num_stitches`` equals the smallest depth.
    """
    depths = sorted(depths)
    num_stitches = depths[0]

    # anchor configurations
    anchor_cfgs = [{'comb_id': [0], }, {'comb_id': [1], }]

    # Leave block i of the small anchor and enter the large anchor at the
    # proportionally scaled block; the last small block gets no entry.
    stitched_cfgs = [
        {
            'comb_id': (0, 1),
            'stitch_cfgs': (blk, (blk + 1) * (depths[1] // depths[0])),
        }
        for blk in range(num_stitches - 1)
    ]

    return anchor_cfgs + stitched_cfgs, num_stitches


def get_stitch_configs_bidirection(depths):
    """Enumerate stitching configurations in both directions.

    Five families are produced, in this order inside ``total_configs``:
    the two single-anchor entries, small -> large (``[0, 1]``),
    large -> small (``[1, 0]``), large -> small -> large (``[1, 0, 1]``)
    and small -> large -> small (``[0, 1, 0]``).

    Every multi-anchor entry carries:
      * ``comb_id``: sequence of anchor indices visited,
      * ``stitch_cfgs``: one ``[front, end]`` pair per stitch (block left in
        the current anchor, block entered in the next anchor),
      * ``stitch_layer_ids``: one id per stitch plus a trailing ``-1``,
      * ``route``: one ``[start, stop]`` pair per visited anchor.

    Args:
        depths: block counts of the anchors; sorted ascending before use.
            Only the first two entries are used.

    Returns:
        tuple: ``(total_configs, model_combos, [n_sl, n_ls], anchor_ids,
        sl_ids, ls_ids, lsl_ids, sls_ids)`` where the ``*_ids`` lists hold
        indices into ``total_configs``. ``sl_ids`` omits the small -> large
        entry that stitches after the last small block, although that entry
        is still present in ``total_configs`` and counted in ``n_sl``.
    """
    depths = sorted(depths)
    small_depth, large_depth = depths[0], depths[1]
    last_small_block = small_depth - 1

    # anchor configurations
    total_configs = [{'comb_id': [0], }, {'comb_id': [1], }]

    num_stitches = small_depth

    # small --> large
    sl_configs = [
        {
            'comb_id': [0, 1],
            'stitch_cfgs': [[i, (i + 1) * (large_depth // small_depth)]],
            'stitch_layer_ids': [i],
        }
        for i in range(num_stitches)
    ]

    # For every block of the large anchor, the corresponding small block
    # (ids of the small anchor stretched to the large anchor's length).
    block_ids = torch.tensor(list(range(small_depth)))
    block_ids = block_ids[None, None, :].float()
    end_mapping_ids = torch.nn.functional.interpolate(block_ids, large_depth)
    end_mapping_ids = end_mapping_ids.squeeze().long().tolist()

    # large --> small: never after the last large block; when the anchors
    # differ in depth only odd large blocks are used.
    same_depth = large_depth == small_depth
    ls_configs = [
        {
            'comb_id': [1, 0],
            'stitch_cfgs': [[i, end_mapping_ids[i] + 1]],
            'stitch_layer_ids': [i // (large_depth // small_depth)],
        }
        for i in range(large_depth - 1)
        if same_depth or i % 2 == 1
    ]

    # large --> small --> large: the second stitch must leave the small
    # anchor at or after the block where the first stitch entered it.
    lsl_configs = []
    for ls_cfg in ls_configs:
        for sl_cfg in sl_configs:
            if sl_cfg['stitch_layer_ids'][0] == last_small_block:
                continue
            if sl_cfg['stitch_cfgs'][0][0] >= ls_cfg['stitch_cfgs'][0][1]:
                lsl_configs.append({
                    'comb_id': [1, 0, 1],
                    'stitch_cfgs': [ls_cfg['stitch_cfgs'][0], sl_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': ls_cfg['stitch_layer_ids'] + sl_cfg['stitch_layer_ids']
                })

    # small --> large --> small
    sls_configs = []
    for sl_cfg in sl_configs:
        for ls_cfg in ls_configs:
            if ls_cfg['stitch_cfgs'][0][0] >= sl_cfg['stitch_cfgs'][0][1]:
                sls_configs.append({
                    'comb_id': [0, 1, 0],
                    'stitch_cfgs': [sl_cfg['stitch_cfgs'][0], ls_cfg['stitch_cfgs'][0]],
                    'stitch_layer_ids': sl_cfg['stitch_layer_ids'] + ls_cfg['stitch_layer_ids']
                })

    total_configs += sl_configs + ls_configs + lsl_configs + sls_configs

    anchor_ids = []
    sl_ids = []
    ls_ids = []
    lsl_ids = []
    sls_ids = []

    for i, cfg in enumerate(total_configs):
        comb_id = cfg['comb_id']

        if len(comb_id) == 1:
            anchor_ids.append(i)
            continue

        final_depth = depths[comb_id[-1]]

        if len(comb_id) == 2:
            front, end = cfg['stitch_cfgs'][0]
            cfg['route'] = [[0, front], [end, final_depth]]
            # Skip the stitch after the last small block. This was a
            # hard-coded 11, which is only right for a 12-block small anchor.
            if comb_id == [0, 1] and front != last_small_block:
                sl_ids.append(i)
            elif comb_id == [1, 0]:
                ls_ids.append(i)
        elif len(comb_id) == 3:
            front_1, end_1 = cfg['stitch_cfgs'][0]
            front_2, end_2 = cfg['stitch_cfgs'][1]
            cfg['route'] = [[0, front_1], [end_1, front_2], [end_2, final_depth]]

            if comb_id == [1, 0, 1]:
                lsl_ids.append(i)
            elif comb_id == [0, 1, 0]:
                sls_ids.append(i)

        # The final route segment has no stitch after it; -1 marks that.
        cfg['stitch_layer_ids'].append(-1)

    model_combos = [(0, 1), (1, 0)]
    return total_configs, model_combos, [len(sl_configs), len(ls_configs)], anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids


def format_out_features(outs, with_cls_token, hw_shape):
    """Convert every entry of ``outs`` to channel-first 4-D layout, in place.

    The rank of the first entry decides the path for the whole list:
      * 4-D entries are permuted with ``(0, 3, 1, 2)``;
      * otherwise entries are treated as ``(B, tokens, C)``, the first token
        is dropped when ``with_cls_token`` is true, and the rest is reshaped
        to ``(B, hw_shape[0], hw_shape[1], C)`` before the same permute.

    Returns:
        The same ``outs`` list, with its elements replaced.
    """
    if len(outs[0].shape) == 4:
        for pos, feat in enumerate(outs):
            outs[pos] = feat.permute(0, 3, 1, 2).contiguous()
        return outs

    batch, _, channels = outs[0].shape
    height, width = hw_shape[0], hw_shape[1]
    for pos, feat in enumerate(outs):
        # Remove class token (if any) and reshape tokens for the decoder head
        tokens = feat[:, 1:] if with_cls_token else feat
        grid = tokens.reshape(batch, height, width, channels)
        outs[pos] = grid.permute(0, 3, 1, 2).contiguous()
    return outs


def zero_module(module):
    """Fill every parameter of ``module`` with zeros in place.

    The zeroing goes through a detached view of each parameter.

    Returns:
        The same ``module`` object.
    """
    for param in module.parameters():
        detached = param.detach()
        detached.zero_()
    return module

# import loralib as lora


class StitchingLayer(BaseModule):
    """Single projection that carries features from one anchor to another.

    Args:
        in_features: input width passed to ``Linear``.
        out_features: output width passed to ``Linear``.
        r: third positional argument passed to ``Linear`` (imported from
            ``.lora``); presumably the LoRA rank -- confirm against that
            module.
    """

    def __init__(self, in_features=None, out_features=None, r=0):
        super().__init__()
        self.transform = Linear(in_features, out_features, r)

    def init_stitch_weights_bias(self, weight, bias):
        """Copy ``weight`` and ``bias`` into the projection's parameters in place."""
        projection = self.transform
        projection.weight.data.copy_(weight)
        projection.bias.data.copy_(bias)

    def forward(self, x):
        """Return ``x`` passed through the projection."""
        return self.transform(x)


@BACKBONES.register_module()
class SNNetv1(BaseModule):
    """Stitchable backbone with one-directional (small -> large) stitching.

    Args:
        anchors: iterable of keyword-argument dicts, each used to build one
            ``MAE`` anchor. Index 0 is treated as the small anchor and
            index 1 as the large one.
    """

    def __init__(self, anchors=None):
        super().__init__()
        self.anchors = nn.ModuleList([MAE(**anchor_cfg) for anchor_cfg in anchors])

        self.with_cls_token = self.anchors[0].with_cls_token
        self.depths = [anchor.num_layers for anchor in self.anchors]

        small, large = self.anchors[0], self.anchors[1]

        # reset out indices of small so they line up with those of large
        small.out_indices = reset_out_indices(
            self.depths[0], self.depths[1], large.out_indices)

        total_configs, num_stitches = get_stitch_configs_general_unequal(self.depths)
        self.stitch_layers = nn.ModuleList([
            StitchingLayer(small.embed_dims, large.embed_dims)
            for _ in range(num_stitches)
        ])

        self.stitch_configs = dict(enumerate(total_configs))
        self.all_cfgs = list(self.stitch_configs)
        self.num_configs = len(total_configs)
        self.stitch_config_id = 0

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by ``forward`` outside training mode."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Initialise every stitching layer by least squares on input ``x``.

        Block features of both anchors are extracted without gradients; the
        layer for small block ``i`` is fitted against the features at the
        proportionally scaled block of the large anchor.
        """
        front, end = 0, 1
        with torch.no_grad():
            front_features = self.anchors[front].extract_block_features(x)
            end_features = self.anchors[end].extract_block_features(x)

        for layer_idx in range(self.depths[0]):
            end_id = (layer_idx + 1) * (self.depths[1] // self.depths[0])
            weight, bias = ps_inv(front_features[layer_idx], end_features[end_id - 1])
            self.stitch_layers[layer_idx].init_stitch_weights_bias(weight, bias)
            print(f'Initialized Stitching Model {front} to Model {end}, Layer {layer_idx}')

    def init_weights(self):
        """Delegate weight initialisation to each anchor."""
        for anchor in self.anchors:
            anchor.init_weights()

    def forward(self, x):
        """Run one stitch configuration and return channel-first feature maps.

        In training mode a configuration id is drawn uniformly at random on
        every call; otherwise ``self.stitch_config_id`` is used.
        """
        # randomly sample a stitch at each training iteration
        cfg_id = (np.random.randint(0, self.num_configs)
                  if self.training else self.stitch_config_id)

        chosen = self.stitch_configs[cfg_id]
        comb_id = chosen['comb_id']
        front_anchor = self.anchors[comb_id[0]]

        if len(comb_id) == 1:
            outs, hw_shape = front_anchor(x)
            # in case forwarding the smaller anchor, project its outputs
            if comb_id[0] == 0:
                for pos, out_idx in enumerate(front_anchor.out_indices):
                    outs[pos] = self.stitch_layers[out_idx](outs[pos])
            return format_out_features(outs, self.with_cls_token, hw_shape)

        stitch = chosen['stitch_cfgs']
        front_blk, end_blk = stitch[0], stitch[1]
        x, outs, hw_shape = front_anchor.forward_until(x, blk_id=front_blk)

        # Outputs collected before the stitch point go through the stitching
        # layer of their own block.
        for pos, out_idx in enumerate(front_anchor.out_indices):
            if out_idx < front_blk:
                outs[pos] = self.stitch_layers[out_idx](outs[pos])

        x = self.stitch_layers[front_blk](x)
        if front_blk in front_anchor.out_indices:
            outs[-1] = x

        # Values unused; the unpacking still requires ``x`` to be 3-D.
        _batch, _tokens, _channels = x.shape
        tail_outs = self.anchors[comb_id[1]].forward_from(x, blk_id=end_blk)
        outs += tail_outs

        return format_out_features(outs, self.with_cls_token, hw_shape)


@BACKBONES.register_module()
class SNNetv2(BaseModule):
    """Stitchable backbone with bidirectional stitching between two anchors.

    Args:
        anchors: iterable of keyword-argument dicts, each used to build one
            ``MAE`` anchor. Index 0 is treated as the small anchor and
            index 1 as the large one.
        selected_ids: explicit list of configuration ids to expose in
            ``all_cfgs``. ``None`` or an empty sequence means "build the
            list from the ``include_*`` flags".
        include_sl: include small -> large configurations.
        include_ls: include large -> small configurations.
        include_lsl: include large -> small -> large configurations.
        include_sls: include small -> large -> small configurations.
        lora_r: value forwarded as ``r`` to every ``StitchingLayer``.
        pretrained: accepted but not used by this class.

    Note:
        ``sampling_stitch_config`` (called by ``forward`` in training mode)
        reads ``self.flops_grouped_cfgs``. This class never assigns that
        attribute; it has to be set on the instance before training.
    """

    def __init__(self, anchors=None, selected_ids=None, include_sl=True, include_ls=True, include_lsl=True, include_sls=True, lora_r=0, pretrained=None):
        super(SNNetv2, self).__init__()

        self.lora_r = lora_r
        self.anchors = nn.ModuleList()
        for cfg in anchors:
            self.anchors.append(MAE(**cfg))

        self.with_cls_token = self.anchors[0].with_cls_token
        self.depths = [anc.num_layers for anc in self.anchors]

        # reset out indices of small so they line up with those of large
        self.anchors[0].out_indices = reset_out_indices(self.depths[0], self.depths[1], self.anchors[1].out_indices)

        total_configs, model_combos, num_stitches, anchor_ids, sl_ids, ls_ids, lsl_ids, sls_ids = get_stitch_configs_bidirection(self.depths)

        self.stitch_layers = nn.ModuleList()
        self.stitching_map_id = {}

        # One group of stitching layers per direction. The trailing Identity
        # is what a stitch_layer_id of -1 selects, i.e. "no projection".
        for (front, end), num_sth in zip(model_combos, num_stitches):
            group = nn.ModuleList(
                [StitchingLayer(self.anchors[front].embed_dims, self.anchors[end].embed_dims, lora_r) for _ in range(num_sth)])
            group.append(nn.Identity())
            self.stitch_layers.append(group)

        self.stitch_configs = {i: cfg for i, cfg in enumerate(total_configs)}
        self.stitch_init_configs = {i: cfg for i, cfg in enumerate(total_configs) if len(cfg['comb_id']) == 2}

        # A fresh list per instance: the default is no longer a shared
        # mutable object, and the caller's list is not aliased.
        selected_ids = [] if selected_ids is None else list(selected_ids)
        self.selected_ids = selected_ids
        if len(selected_ids) == 0:
            self.all_cfgs = list(anchor_ids)

            if include_sl:
                self.all_cfgs += sl_ids

            if include_ls:
                self.all_cfgs += ls_ids

            if include_lsl:
                self.all_cfgs += lsl_ids

            if include_sls:
                self.all_cfgs += sls_ids
        else:
            self.all_cfgs = list(selected_ids)

        self.trained_cfgs = {idx: self.stitch_configs[idx] for idx in self.all_cfgs}

        print(str(self.all_cfgs))
        self.num_configs = len(self.stitch_configs)
        self.stitch_config_id = 0

    def reset_stitch_id(self, stitch_config_id):
        """Select the configuration used by ``forward`` and ``get_stitch_parameters``."""
        self.stitch_config_id = stitch_config_id

    def initialize_stitching_weights(self, x):
        """Initialise the two-anchor stitching layers by least squares on ``x``.

        Block features of every anchor are extracted without gradients. For
        each two-anchor configuration, the layer is fitted from the features
        at the block that is left to the features just before the block that
        is entered.
        """
        anchor_features = []
        for anchor in self.anchors:
            with torch.no_grad():
                anchor_features.append(anchor.extract_block_features(x))

        for cfg in self.stitch_init_configs.values():
            comb_id = cfg['comb_id']
            if len(comb_id) != 2:
                continue
            front_id, end_id = cfg['stitch_cfgs'][0]
            stitch_layer_id = cfg['stitch_layer_ids'][0]
            front_blk_feat = anchor_features[comb_id[0]][front_id]
            end_blk_feat = anchor_features[comb_id[1]][end_id - 1]
            w, b = ps_inv(front_blk_feat, end_blk_feat)
            self.stitch_layers[comb_id[0]][stitch_layer_id].init_stitch_weights_bias(w, b)
            print(f'Initialized Stitching Layer {cfg}')

    def resize_abs_pos_embed(self, state_dict):
        """Interpolate checkpoint position embeddings to each anchor's grid.

        Every key containing ``'pos_embed'`` is processed; the anchor index
        is parsed from the second dot-separated component of the key. The
        entries of ``state_dict`` are replaced in place when the checkpoint
        grid size differs from the anchor's, and ``state_dict`` is returned.
        """
        pos_keys = [k for k in state_dict.keys() if 'pos_embed' in k]

        for pos_k in pos_keys:
            anchor = self.anchors[int(pos_k.split('.')[1])]
            pos_embed_checkpoint = state_dict[pos_k]
            embedding_size = pos_embed_checkpoint.shape[-1]
            num_extra_tokens = anchor.pos_embed.shape[-2] - anchor.num_patches
            # height (== width) for the checkpoint position embedding
            orig_size = int(
                (pos_embed_checkpoint.shape[-2] - num_extra_tokens)**0.5)
            # height (== width) for the new position embedding
            new_size = int(anchor.num_patches**0.5)
            if orig_size == new_size:
                continue
            # class_token and dist_token are kept unchanged
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
                                            embedding_size).permute(
                                                0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode=anchor.interpolate_mode,
                align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            state_dict[pos_k] = torch.cat((extra_tokens, pos_tokens), dim=1)
        return state_dict

    def init_weights(self):
        """Delegate weight initialisation to each anchor."""
        for anc in self.anchors:
            anc.init_weights()

    def sampling_stitch_config(self):
        """Draw a FLOPs group uniformly, then a configuration id from it.

        The result is stored in ``self.stitch_config_id``.

        Raises:
            AttributeError: if ``flops_grouped_cfgs`` has not been assigned
                on this instance.
        """
        groups = getattr(self, 'flops_grouped_cfgs', None)
        if groups is None:
            raise AttributeError(
                "'flops_grouped_cfgs' must be assigned on the SNNetv2 "
                'instance before sampling stitch configurations.')
        flops_id = np.random.choice(len(groups))
        self.stitch_config_id = np.random.choice(groups[flops_id])

    def get_stitch_parameters(self):
        """Count the parameters used by the currently selected configuration.

        Mirrors the control flow of ``forward``: the anchor segments on the
        route, the stitching layers applied to the small anchor's outputs,
        and the stitching layer after each segment.
        """
        stitch_cfg = self.stitch_configs[self.stitch_config_id]
        comb_id = stitch_cfg['comb_id']

        total_params = 0

        # forward by a single anchor
        if len(comb_id) == 1:
            anchor = self.anchors[comb_id[0]]
            total_params += sum(p.numel() for p in anchor.parameters())
            # in case forwarding the smaller anchor
            if comb_id[0] == 0:
                for out_idx in anchor.out_indices:
                    total_params += sum(p.numel() for p in self.stitch_layers[0][out_idx].parameters())

            return total_params

        # forward among anchors
        route = stitch_cfg['route']
        stitch_layer_ids = stitch_cfg['stitch_layer_ids']

        # patch embedding
        total_params += self.anchors[comb_id[0]].patch_embed_params()

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            total_params += self.anchors[model_id].selective_params(cfg[0], cfg[1])

            if model_id == 0:
                mapping_idx = [idx for idx in self.anchors[model_id].out_indices if cfg[0] <= idx <= cfg[1]]
                for out_idx in mapping_idx:
                    total_params += sum(p.numel() for p in self.stitch_layers[model_id][out_idx].parameters())

            total_params += sum(p.numel() for p in self.stitch_layers[model_id][stitch_layer_ids[i]].parameters())

        return total_params

    def forward(self, x):
        """Run the selected stitch configuration and return channel-first maps.

        In training mode a new configuration is sampled first via
        ``sampling_stitch_config``; otherwise ``self.stitch_config_id`` is
        used as is.
        """
        if self.training:
            self.sampling_stitch_config()

        stitch_cfg = self.stitch_configs[self.stitch_config_id]
        comb_id = stitch_cfg['comb_id']

        # forward by a single anchor
        if len(comb_id) == 1:
            anchor = self.anchors[comb_id[0]]
            outs, hw_shape = anchor(x)
            # in case forwarding the smaller anchor, project its outputs
            if comb_id[0] == 0:
                for i, out_idx in enumerate(anchor.out_indices):
                    outs[i] = self.stitch_layers[0][out_idx](outs[i])
            return format_out_features(outs, self.with_cls_token, hw_shape)

        # forward among anchors
        route = stitch_cfg['route']
        stitch_layer_ids = stitch_cfg['stitch_layer_ids']

        # patch embedding
        x, hw_shape = self.anchors[comb_id[0]].forward_patch_embed(x)
        final_outs = []

        for i, (model_id, cfg) in enumerate(zip(comb_id, route)):
            x, outs = self.anchors[model_id].selective_forward(x, cfg[0], cfg[1])
            if model_id == 0:
                # Outputs taken from the small anchor are projected by the
                # small -> large stitching layer of their block.
                mapping_idx = [idx for idx in self.anchors[model_id].out_indices if cfg[0] <= idx <= cfg[1]]
                for j, out_idx in enumerate(mapping_idx):
                    outs[j] = self.stitch_layers[model_id][out_idx](outs[j])

            final_outs += outs
            # Stitch into the next anchor; id -1 selects the Identity that
            # closes every group, used after the last segment.
            x = self.stitch_layers[model_id][stitch_layer_ids[i]](x)

        return format_out_features(final_outs, self.with_cls_token, hw_shape)