import json
import os
import torch
from torch import nn, einsum
from typing import Optional, Dict, Tuple
from .mae_vit import MAEViT
from .htsat import HTSAT_Swin_Transformer, create_htsat_model
from .LMdecoder import LMDecoder, LMDecoder_qlora
from .vision_transformer import VisionTransformer
from einops import rearrange, repeat
from einops_exts import rearrange_many
import inspect

from transformers.modeling_utils import PreTrainedModel
from .configuration_maelm import MAELMConfig

class ArgsHandler:
    def __init__(self, module, funcname, fargs, fkargs):
        self.fargs = list(fargs)
        self.fkargs = fkargs
        func = getattr(module, funcname)
        fal_repr = f"{funcname}_argnames_list"
        if (argns_list:=getattr(module, fal_repr, None)) is None:
            self.func_sig = inspect.signature(func)
            self.argnames_list = list(self.func_sig.parameters.keys())
            setattr(module, fal_repr, self.argnames_list)
        else:
            self.argnames_list = argns_list

    def get_arg(self, arg_name):
        if arg_name in self.fkargs:
            arg = self.fkargs[arg_name]
        else:
            arg = self.fargs[self.argnames_list.index(arg_name)]
        return arg

    def set_arg(self, arg_name, arg_value):
        if arg_name in self.fkargs:
            self.fkargs[arg_name] = arg_value
        else:
            self.fargs[self.argnames_list.index(arg_name)] = arg_value

    def return_all_args(self,):
        return tuple(self.fargs), self.fkargs
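
# Minimal usage sketch (illustrative only; the toy module below is hypothetical):
# ArgsHandler resolves a hooked forward() call's positional/keyword arguments by
# parameter name, so a pre-hook can read and rewrite them uniformly.
def _demo_args_handler():
    class _ToyLayer(nn.Module):
        def forward(self, hidden_states, attention_mask=None):
            return hidden_states

    layer = _ToyLayer()
    fargs, fkargs = (torch.zeros(2, 3),), {'attention_mask': torch.ones(2, 3)}
    ahl = ArgsHandler(layer, 'forward', fargs, fkargs)
    hs = ahl.get_arg('hidden_states')       # resolved positionally via the signature
    ahl.set_arg('hidden_states', hs + 1.0)  # rewrite in place
    return ahl.return_all_args()            # (args, kwargs) ready to return from a pre-hook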

class SquaredReLU(nn.Module):
    """ squared ReLU activation function"""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(torch.relu(x), 2)

def FeedForward(dim, out_dim, mult=4, act='gelu'):
    """
    lucidrains implementation, slightly modified with the act parameter.
    """

    acts = dict(
        gelu=nn.GELU,
        sqrelu=SquaredReLU,
        relu=nn.ReLU
    )

    assert act in acts, f"act. can only be one of {acts.keys()}"

    inner_dim = int(dim * mult)
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, inner_dim, bias=False),
        acts[act](),
        nn.Linear(inner_dim, out_dim, bias=False)
    )
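
# Shape sketch (illustrative dims): FeedForward keeps the token layout and only maps
# the channel dimension, e.g.
#
#     ff = FeedForward(dim=1024, out_dim=4096, mult=4, act='gelu')
#     y = ff(torch.randn(2, 32, 1024))   # -> torch.Size([2, 32, 4096])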


class PerceiverAttentionLayer(nn.Module):
    def __init__(
            self,
            *,
            feat_dim,
            latent_dim,
            dim_head=64,
            heads=8
        ):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        self.dim_head = dim_head

        inner_dim = dim_head * heads

        # trainable components of PerceiverAttentionLayer
        self.norm_media = nn.LayerNorm(feat_dim)
        self.norm_latents = nn.LayerNorm(latent_dim)

        self.to_q = nn.Linear(latent_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(feat_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, latent_dim, bias=False)

    def forward(self, features, latents):
        """
        Latent vectors are cross-attending to the visual features x.
        :param x:       Tensor (n_batch, n_features, dim)
                        visual features
        :param latents: Tensor (n_batch, n_latents, dim)
                        latent learnt vectors from which the queries are computed.
                        Actually the same, just replicated in n_batch and n_frames dimension.
        :return:        Tensor (n_batch, n_latents, dim)
        """
        assert features.ndim == 3
        assert latents.ndim == 3
        assert features.shape[0] == latents.shape[0]
        #assert features.shape[2] == latents.shape[2]

        n_heads = self.heads
        n_batch, n_features, dim = features.shape
        n_queries = latents.shape[1]

        # layer normalization, as usual
        x = self.norm_media(features)
        latents = self.norm_latents(latents)

        # queries
        # compute the queries from the latents, for all attention heads simultaneously.
        q = self.to_q(latents)
        q = rearrange(q, 'b q (h d) -> b h q d', h=n_heads)
        assert q.shape == torch.Size([n_batch, n_heads, n_queries, self.dim_head])

        # keys and values for all attention heads.
        # The latents are not concatenated to the key/value input (the commented-out
        # variant); attention runs over the features only.
        kv_input = x
        n_features_latents = n_features
        # keys, values
        k = self.to_k(kv_input)
        v = self.to_v(kv_input)
        # batch, features, (heads, dim)

        # split so we have an extra dimension for the heads
        # q, k, v = rearrange_many((q, k, v), 'b t n (h d) -> b h t n d', h=h)
        k, v = rearrange_many((k, v), 'b f (h d) -> b h f d', h=n_heads)
        assert v.shape == torch.Size([n_batch, n_heads, n_features_latents, self.dim_head])

        # scale queries
        q = q * self.scale

        # attention scores
        # sim = einsum('... i d, ... j d  -> ... i j', q, k)
        sim = einsum('b h q d, b h f d -> b h q f', q, k)

        # subtract the row-wise max for numerical stability; this does not change
        # the softmax output
        sim = sim - sim.amax(dim=-1, keepdim=True).detach()
        alphas = sim.softmax(dim=-1)

        # out = einsum('... i j, ... j d -> ... i d', alphas, v)
        out = einsum('b h q f, b h f v -> b h q v', alphas, v)

        # out = rearrange(out, 'b h t n d -> b t n (h d)', h=h)
        out = rearrange(out, 'b h q v -> b q (h v)')
        return self.to_out(out)
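
# Minimal shape sketch (illustrative; the dims below are assumptions, not values used
# elsewhere): latents cross-attend to the features, so the output keeps the latents'
# sequence length and latent_dim channels regardless of the number of feature tokens.
def _demo_perceiver_attention_layer():
    layer = PerceiverAttentionLayer(feat_dim=1024, latent_dim=4096, dim_head=64, heads=8)
    feats = torch.randn(2, 512, 1024)   # (n_batch, n_features, feat_dim)
    latents = torch.randn(2, 32, 4096)  # (n_batch, n_latents, latent_dim)
    out = layer(feats, latents)         # -> torch.Size([2, 32, 4096])
    return out.shape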


class MAEForCausalLM(PreTrainedModel):
    """Audio-conditioned causal language model.

    An audio encoder (backbone) produces features that a perceiver-style resampler
    and bottleneck adapters inject into a language-model decoder (neck) via forward
    pre-hooks.

    Args:
        config (MAELMConfig): Configuration whose ``backbone`` and ``neck`` dicts
            select and parameterize the audio encoder and the LM decoder.
    """
    
    config_class = MAELMConfig

    def __init__(self, config: MAELMConfig) -> None:
        super().__init__(config)
        backbone = config.backbone
        assert backbone is not None
        bk_name = backbone.pop('name')
        self.bk_name = bk_name
        if bk_name == 'MAEViT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = MAEViT(**backbone)
            #if ckpt_path is not None:
            #    ckpt = torch.load(ckpt_path, map_location='cpu')
            #    self.backbone.load_state_dict(ckpt['state_dict'])

        elif bk_name == 'HTSAT':
            ckpt_path = backbone.pop('ckpt') if 'ckpt' in backbone else None
            self.backbone = create_htsat_model(backbone)
            if ckpt_path is not None:
                ckpt = torch.load(ckpt_path, map_location='cpu')
                self.backbone.load_state_dict(ckpt['state_dict'])
        elif bk_name == 'qformer':
            raise NotImplementedError(f"backbone '{bk_name}' is not implemented")
        else:
            raise NotImplementedError(f"unknown backbone '{bk_name}'")



        # neck["num_patches"] = self.backbone.num_patches
        # neck["patch_resolution"] = self.backbone.patch_resolution
        neck = config.neck
        assert neck is not None
        nk_name = neck.pop('name')
        if nk_name == 'LMDecoder':
            self.neck = LMDecoder(**neck)
        elif nk_name == 'LMDecoder_qlora':
            self.neck = LMDecoder_qlora(**neck)
        else:
            raise NotImplementedError(f"unknown neck '{nk_name}'")
        self.config = self.neck.LMconfig # TODO

        '''
        self.ae_proj = nn.Linear(
            768,  self.config.hidden_size
        )
        '''
        
        ## TODO

        #self.neck.lm.apply(lambda m:m.gradient_checkpointing=True)
        self.neck.lm.model.gradient_checkpointing = False

        self.register_buffer('ones', torch.ones((1,4096), dtype=torch.long), persistent=False)
        self.graft_adapter()
        self.init_weights()
        # float32 --> bfloat16
        for p in self.parameters():
            p.data = p.data.to(torch.bfloat16)
        #if config.resume_from_checkpoint is not None:  
        #    drain_loader = True
        #    accelerator.load_state(config.resume_from_checkpoint, load_module_strict=False)
        #    # start_epoch, start_step, all_step = [int(_.split('_')[1]) for _ in args.resume_from_checkpoint.split('/')[-2].split('-')]
        #elif config.resume_from_pth is not None:
        #    print(f'###########loading##########{config.resume_from_pth}###########loading##########')
        #    ckpt = torch.load(config.resume_from_pth, map_location='cpu')
        #    ckpt_copy = {k[7:]: v for k, v in ckpt.items()}
        #    self.load_state_dict(ckpt_copy, strict=False)
        #    print(f'###########loaded##########{config.resume_from_pth}###########loaded##########')

        if False:
            self.patch_llm()
        self.first_run = True
    
    def graft_adapter(self):
        adapter_latent_len = 32
        self.adapter_latent_len = adapter_latent_len
        self.adapter_latent = nn.Parameter(torch.rand((1,adapter_latent_len, self.config.hidden_size), \
                                                     dtype=torch.float))
        resampler_latent_len = 32
        self.resampler_latent_len = resampler_latent_len
        self.resampler_latent = nn.Parameter(torch.rand((1,resampler_latent_len, self.config.hidden_size), \
                                                     dtype=torch.float))
        ## TODO
        # self.adapter.pre_bn = torch.nn.BatchNorm1d(4096, affine=True)

        self.adapter = nn.ModuleList([])
        
        ff_mult = 4
        heads=8
        dim_head=512
        act='gelu'

        lm_dim = self.config.hidden_size
        if self.bk_name == 'HTSAT':
            feat_dim = 1024
            depth = len(self.backbone.layers[2].blocks)
        else:
            feat_dim = 768
            depth = len(self.neck.lm.model.layers) // 2  # e.g. 16 for a 32-layer LM
        for idx in range(depth):
            self.adapter.append(nn.ModuleList([
                Adapter(input_size=self.config.hidden_size),
                # PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=dim_head, heads=heads),
                # FeedForward(dim=lm_dim, out_dim=lm_dim, mult=1, act=act),
                #FeedForward(dim=self.dim, out_dim=768, mult=ff_mult, act=act) if idx != depth-1 else nn.Identity()
            ])) 

        self.samplers = nn.ModuleList([]) # add
        for _ in range(3):
            self.samplers.append(nn.ModuleList([
                PerceiverAttentionLayer(feat_dim=feat_dim, latent_dim=lm_dim, dim_head=64, heads=heads),
                FeedForward(dim=lm_dim, out_dim=lm_dim, mult=4),
            ]))
        self.norm = nn.LayerNorm(lm_dim)

        # self.agate_list = nn.ParameterList([])
        # for i in range(len(self.neck.lm.model.layers)):
        #     self.agate_list.append(nn.Parameter(torch.zeros(lm_dim)))


        
    def init_weights(self):
        try:
            super().init_weights()
        except Exception:
            pass
        if getattr(self, 'adapter_latent', None) is not None:
            self.adapter_latent.data.normal_(mean=0.0, std=0.02)
        if getattr(self, 'resampler_latent', None) is not None:
            self.resampler_latent.data.normal_(mean=0.0, std=0.02)

    def forward_resampler(self, x):
        # x: (batch, n_feature_tokens, feat_dim) audio features from the backbone
        latents = repeat(self.resampler_latent, 'b n d -> (bs b) n d', bs=x.shape[0])
        for attn, ff in self.samplers:
            latents = attn(x, latents) + latents
            latents = ff(latents) + latents
        v2t_feats = self.norm(latents)
        # v2t_atts = torch.ones(v2t_feats.shape[:2], dtype=torch.long, device=v2t_feats.device)
        return v2t_feats  # (batch, resampler_latent_len, lm_hidden_size)


    def hook_adapter(self, audio_embedding, lm, v2t_feats):
        
        class PHooker:
            # model = self.backbone
            # mgtr = self.backbone.forward_generator(spectrogram)
            adapter = self.adapter
            y = v2t_feats
            handles_list = list()
            cnter = 0
            def layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                
                # print(self.cnter)
                
                # if self.cnter>=16:
                #     self.cnter+=1
                #     return None
                adapt = self.adapter[self.cnter][0]

                hs = ahl.get_arg("hidden_states")
                adapter_residual = hs
                neo_hs = adapt(hs, adapter_residual)

                self.cnter+=1
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()
            def first_layer_prehook(self, m, margs, mkargs):
                ahl = ArgsHandler(m, 'forward', margs, mkargs)
                neo_lm_latents = self.y #  torch.Size([128, 32, 4096])
                hs = ahl.get_arg("hidden_states") # torch.Size([128, 87, 4096])
                hs_msk = self.lm_ahl.get_arg("input_ids") < 0  # torch.Size([128, 87]); True at the audio placeholder positions
                # directly replace the audio placeholder positions with the resampler latents
                neo_hs = hs.masked_scatter(hs_msk.unsqueeze(-1), neo_lm_latents)
                ahl.set_arg("hidden_states", neo_hs)
                return ahl.return_all_args()

            def lm_prehook(self, m, margs, mkargs):
                self.lm_ahl = ArgsHandler(m, 'forward', margs, mkargs)
                return None
            def last_layer_hook(self, m, margs, mkargs):
                # reset the per-forward layer counter
                self.cnter = 0

        if getattr(lm,'phooker',False):
            for _ in lm.phooker.handles_list:
                _.remove()
            del lm.phooker
            lm.phooker = None
        phooker = PHooker()
        phooker.handles_list.append(lm.register_forward_pre_hook(phooker.lm_prehook, with_kwargs=True))
        # insert the resampler latents at the first decoder layer
        phooker.handles_list.append(lm.model.layers[0].register_forward_pre_hook(phooker.first_layer_prehook, with_kwargs=True))
       
        for ii in range(1,len(lm.model.layers),2):
            l = lm.model.layers[ii]
            handle = l.register_forward_pre_hook(phooker.layer_prehook, with_kwargs=True)
            phooker.handles_list.append(handle)
        phooker.handles_list.append(lm.model.layers[-1].register_forward_pre_hook(phooker.last_layer_hook, with_kwargs=True))
        lm.phooker = phooker 
        return None
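
    # Sketch of the hook mechanism used above (illustrative, self-contained): a forward
    # pre-hook registered with with_kwargs=True receives (module, args, kwargs) and may
    # return a replacement (args, kwargs) tuple, which is how the prehooks above rewrite
    # "hidden_states" before each hooked decoder layer runs, e.g.:
    #
    #     lin = nn.Linear(4, 4)
    #
    #     def prehook(module, args, kwargs):
    #         (x,) = args
    #         return (x * 2.0,), kwargs      # the layer now sees the scaled input
    #
    #     handle = lin.register_forward_pre_hook(prehook, with_kwargs=True)
    #     _ = lin(torch.randn(1, 4))
    #     handle.remove()                    # mirrors PHooker's cleanup of stale handles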



    def prepare_ids(self, batch, audio_ids):
        toker = self.neck.tokenizer
        # for idx, l in enumerate(self.neck.lm.model.layers):
        #     l.agate = self.agate_list[idx].clone() ## should clone the parameter
         
        with torch.no_grad():
            
            input_ids = batch['input_ids']
            att_msk = batch['attention_mask']
            au_crds = batch['audio_crds']
            ans_crds = batch['ans_crds']
            bsz = input_ids.shape[0]
            # __import__('pdb').set_trace()
            ## TODO
            merged_ids, merged_msk, label_ids = list(), list(), list()
            for i in range(bsz):
                # cur_merged_ids = torch.cat([input_ids[i,:au_crds[i]], -1 * audio_ids[i] -1, input_ids[i,au_crds[i]:]])
                cur_merged_ids = torch.cat([ -1 * audio_ids[i] -1, input_ids[i,au_crds[i]:]])
                
                # cur_au_msk = self.ones[:,:audio_ids.shape[1]][0].clone().type_as(att_msk).detach()
                cur_au_msk = torch.ones(audio_ids.shape[1], device=audio_ids.device)
                # cur_merged_msk = torch.cat([att_msk[i,:au_crds[i]], cur_au_msk, att_msk[i,au_crds[i]:]])
                cur_merged_msk = torch.cat([ cur_au_msk, att_msk[i,au_crds[i]:]])
                cur_label_ids = cur_merged_ids.clone().detach()
                cur_label_ids[:audio_ids.shape[1]+ans_crds[i]] = -100

                merged_ids.append(cur_merged_ids)
                merged_msk.append(cur_merged_msk)
                label_ids.append(cur_label_ids)

            merged_ids = torch.stack(merged_ids, dim=0) 
            merged_msk = torch.stack(merged_msk, dim=0) 
            label_ids = torch.stack(label_ids, dim=0) 

            assert merged_ids.shape[0] == bsz
            assert merged_ids.shape == merged_msk.shape

            label_msk = merged_msk.clone()
            assert label_msk.shape == merged_msk.shape
            assert merged_msk[:,-1].max() == 1

            for i in range(len(ans_crds)):
                label_ids[i,:audio_ids.shape[1]+ans_crds[i]].fill_(-100)
            
             
            merged_labels = label_ids
            merged_ids[merged_ids.eq(-100)] = toker.pad_token_id

        return merged_ids, merged_msk, merged_labels
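
    # Worked example (illustrative): audio_ids[i] = [0, 1, ..., 31] is mapped to the
    # placeholders [-1, -2, ..., -32] and prepended to the text ids; the negative values
    # are what first_layer_prehook later detects (input_ids < 0) in order to scatter the
    # resampler latents into those positions. Labels are -100 for the audio placeholders
    # and everything before ans_crds[i], so only the answer tokens contribute to the loss.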

    def forward(self, batch, **kwargs):
        """Forward computation during training.

        Args:
            batch (Dict[str, torch.Tensor]): Batch containing 'spectrogram',
                'input_ids', 'attention_mask', 'audio_crds' and 'ans_crds'.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            Model outputs from the LM decoder, including the loss.
        """
        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach() # b, 768, 512
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats) # add hook
        
        # self.hook_resapmler(resampler_feats, self.neck.lm)
        
        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100
        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        
        try:
            assert merged_ids.shape == merged_labels.shape
            outs = self.neck(input_ids=merged_ids.contiguous().long(),
                    flatten_embs=self.adapter_latent.flatten(0, 1),  # (32, hidden_size)
                    # flatten_embs=resampler_feats.flatten(0, 1),  # (b*32, hidden_size)
                    attention_mask=merged_msk.contiguous().long(),
                    labels=merged_labels.contiguous().long(), use_cache=False)
        except Exception:
            import traceback
            traceback.print_exc()
            raise
        #outs.hidden_logits = self.hidden_logits

        ## TODO
        if eval(os.environ.get("doing_eval", 'False')):
            outs.merged_ids = merged_ids.cpu()
            outs.merged_labels = merged_labels.cpu()

        return outs


    def forward_test(self, batch, **kwargs):
        """Forward computation during inference/generation.

        Args:
            batch (Dict[str, torch.Tensor]): Batch containing 'spectrogram',
                'input_ids', 'attention_mask', 'audio_crds' and 'ans_crds'.
            kwargs: Any keyword arguments to be used to forward.
        Returns:
            Generated token ids from the LM decoder.
        """


        bsz = len(batch['input_ids'])
        device = batch['input_ids'].device
        float_type = next(self.parameters()).dtype
        spectrogram = batch['spectrogram'].type(float_type)
        audio_embedding = self.backbone(spectrogram).detach() # b, 768, 512
        resampler_feats = self.forward_resampler(audio_embedding)
        self.hook_adapter(audio_embedding, self.neck.lm, resampler_feats) # add hook
        # self.extract_features(batch, self.neck.lm)
        audio_ids = torch.arange(self.adapter_latent.shape[1]).unsqueeze(0).repeat((bsz, 1)).long().to(device)
        assert audio_ids.max() < 100

        merged_ids, merged_msk, merged_labels = self.prepare_ids(batch, audio_ids)
        au_crds = batch['audio_crds']
        ans_crds = batch['ans_crds']
        
        aid_len = audio_ids.shape[-1]
        

        toker = self.neck.tokenizer
        with torch.no_grad():

            ## TODO
            pad_token = toker.encode(self.neck.tokenizer.eos_token)[0]
            padded_merged_ids = self.ones[:, :aid_len+max(ans_crds)].repeat(bsz, 1).clone().detach() * pad_token
            for i in range(bsz):
            # for i in range(1):
                assert au_crds[i] <= ans_crds[i]
                cur_ids = merged_ids[i][:aid_len+ans_crds[i]]
                padded_merged_ids[i][max(ans_crds)-ans_crds[i]:] = cur_ids
        # __import__('pdb').set_trace()
        outs = self.neck.generate(padded_merged_ids, self.adapter_latent.flatten(0,1))
        #outs.hidden_logits = self.hidden_logits

        return outs




from transformers.activations import ACT2FN

class Adapter(nn.Module):
    """
    Implementation of a sequential bottleneck adapter block.
    """
    def __init__(
        self,
        input_size,
        down_sample=None,
    ):
        super().__init__()

        self.input_size = input_size

        # if a downsample size is not passed, we just half the size of the original input
        self.down_sample = down_sample
        if down_sample is None:
            self.down_sample = self.input_size // 2

        self.adapter_norm_before = nn.LayerNorm(self.input_size)
        self.adapter_down = nn.Linear(self.input_size, self.down_sample)
        self.non_linearity = ACT2FN["silu"]

        # Up projection to input size
        self.adapter_up = nn.Linear(self.down_sample, self.input_size)

        # Additional scaling factor (from He et al. (2021))
        self.scaling = nn.Parameter(torch.ones(1))   

        self.adapter_down.apply(self._init_weights)
        self.adapter_up.apply(self._init_weights)

    def forward(self, x, residual_input):

        down = self.non_linearity(self.adapter_down(self.adapter_norm_before(x)))

        up = self.adapter_up(down)
        up = up * self.scaling
        output = up

        output = output + residual_input

        return output

    @staticmethod
    def _init_weights(module):
        """Initialize the weights."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # std defaults to 0.02, this might need to be changed
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
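

# Usage sketch (illustrative; the hidden size is an assumption): the bottleneck adapter
# is residual, so its output has the same shape as its input.
def _demo_adapter():
    adapter = Adapter(input_size=4096)    # down-projects to 2048 by default
    hs = torch.randn(2, 87, 4096)         # (batch, seq_len, hidden_size)
    out = adapter(hs, residual_input=hs)  # -> torch.Size([2, 87, 4096])
    return out.shape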