import copy
from functools import partial
import json
import einops
import torch
from torch import nn

from mmcls.models.backbones.vision_transformer import TransformerEncoderLayer
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPTNeoForCausalLM, GPTNeoModel, \
    BartModel, BartConfig, BartForCausalLM, BertForMaskedLM, AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig

from peft import prepare_model_for_kbit_training
from peft import LoraConfig
from peft import get_peft_model

        
from mmcv.cnn import build_norm_layer
from mmcv.runner import BaseModule
import math
from ipdb import set_trace

class mixEmbed(nn.Module):
    def __init__(self, lm_embed: nn.Embedding, audio_embeddings: torch.Tensor, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.lm_embed = lm_embed
        self.audio_embeddings = audio_embeddings  # ugly, but works without modifying the raw model code
        
    def forward(self, input_ids):
        # Convention: non-negative ids are ordinary text tokens, while an id of
        # -(k + 1) is a placeholder that selects row k of the audio embeddings.
        text_ids = torch.clamp(input_ids.clone(), 0).long()
        au_ids = torch.clamp(-(input_ids.clone() + 1), 0).long()
        text_embeds = self.lm_embed(text_ids)
        au_embeds = self.audio_embeddings[au_ids]
        with torch.no_grad():
            embed_mask = (input_ids >= 0)  # >= 0 so that token id 0 is still treated as text
        mix_embeds = au_embeds.clone()
        mix_embeds[embed_mask] = text_embeds[embed_mask]
        return mix_embeds
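
# A hedged usage sketch of mixEmbed, kept as a reference helper only; the
# vocabulary size, hidden size, and audio-table size below are illustrative
# assumptions rather than values taken from any checkpoint.
def _mix_embed_example():
    lm_embed = nn.Embedding(50257, 768)        # e.g. a GPT-2-sized embedding table
    audio_embeddings = torch.randn(100, 768)   # 100 pre-computed audio embeddings
    mix = mixEmbed(lm_embed, audio_embeddings)
    ids = torch.tensor([[12, -1, -2, 40]])     # text, audio row 0, audio row 1, text
    return mix(ids)                            # -> tensor of shape (1, 4, 768)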
 

class LMDecoder(nn.Module):
    def __init__(self,
                # num_patches=196,
                img_size=(80,512),
                patch_size:int=16,
                in_chans:int=3,
                embed_dim=1024, # encoder embed dim
                decoder_embed_dim=512,
                norm_cfg=dict(type='LN', eps=1e-6),
                # patch_resolution=14,
                decoder_type='gpt2',
                freeze_decoder=True,
                additional_layer:int=0,
                ):
        super().__init__()
        self.decoder_type = decoder_type
        self.load_lm()
        
        self.lm_embed = self.lm.get_input_embeddings()
        try:
            self.lm_pos_embed = self.lm.get_position_embeddings()
        except NotImplementedError:
            self.lm_pos_embed = None  # models with rotary position embeddings expose no separate table

        if hasattr(self.lm, 'embed_dim'):
            self.embed_dim = self.lm.embed_dim
        else:
            self.embed_dim = decoder_embed_dim

        # self.asLM = asLM  # if generates tokens rather than hidden states
        # if self.asLM:  # TODO: why was this added in the first place?
        #     self.lm.set_output_embeddings(nn.Linear(self.embed_dim, self.LMconfig.vocab_size, bias=False))
        self.freeze_decoder = freeze_decoder
        if self.freeze_decoder:
            for para in self.lm.parameters():
                para.requires_grad = False
        
    def load_lm(self):
        ## ---------------------LM setting----------------------
        # Authentication for gated checkpoints should come from the environment
        # (e.g. `huggingface-cli login` or the HF_TOKEN variable), not a hard-coded token.
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type)
        self.lm = AutoModelForCausalLM.from_pretrained(self.decoder_type)
       
        
    def forward(self, input_ids, flatten_embs, attention_mask, labels, **kwargs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding
        output = self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels, output_hidden_states=True, **kwargs)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding table
        return output

    def generate(self, input_ids, flatten_embs):
        mix_embed = mixEmbed(self.lm_embed, flatten_embs)
        self.lm.set_input_embeddings(mix_embed)  # temporarily swap in the mixed text/audio embedding
        outputs = self.lm.generate(input_ids=input_ids, max_new_tokens=256, use_cache=False)
        # Alternative sampling configuration, kept for reference:
        # outputs = self.lm.generate(input_ids=input_ids,
        #                            max_new_tokens=1024,
        #                            do_sample=True,
        #                            temperature=1.5,
        #                            num_beams=1,
        #                            top_p=0.9,
        #                            top_k=3,
        #                            use_cache=False)
        self.lm.set_input_embeddings(self.lm_embed)  # restore the original embedding table
        return outputs
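
# Hedged sketch of a training-style call to LMDecoder.forward. The decoder
# instance and input tensors are assumed to come from the caller; audio
# placeholder positions in `labels` are set to -100 (the Hugging Face loss
# ignore index) so that only text tokens contribute to the loss.
def _decoder_forward_example(decoder: LMDecoder, input_ids, flatten_embs, attention_mask):
    labels = input_ids.clone()
    labels[labels < 0] = -100  # ignore audio placeholder positions in the loss
    output = decoder(input_ids=input_ids,
                     flatten_embs=flatten_embs,
                     attention_mask=attention_mask,
                     labels=labels)
    return output.loss, output.hidden_states[-1]
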
'''
## infer params
max_input_tokens: 40
batch_size_test: 16
max_new_tokens: 64
min_length: 2
num_beams: 5
length_penalty: -2.0
top_p: 0.9
top_k: 3
no_repeat_ngram_size: 2
apply_lemmatizer: False
use_nucleus_sampling: True
'''
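
# Hedged sketch of how the inference parameters listed above could map onto
# Hugging Face `generate()` kwargs. The helper is hypothetical (it is not called
# by the training code) and assumes `lm` is a causal LM whose input embedding
# has already been swapped for the mixed text/audio embedding.
def _generate_with_listed_params(lm, input_ids):
    return lm.generate(
        input_ids=input_ids,
        max_new_tokens=64,
        min_length=2,
        num_beams=5,
        length_penalty=-2.0,
        top_p=0.9,
        top_k=3,
        no_repeat_ngram_size=2,
        do_sample=True,   # use_nucleus_sampling: True
        use_cache=False,
    )
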

class LMDecoder_qlora(LMDecoder):
    def __init__(self,
                # num_patches=196,
                img_size=(80,512),
                patch_size:int=16,
                in_chans:int=3,
                embed_dim=1024, # encoder embed dim
                decoder_embed_dim=512,
                norm_cfg=dict(type='LN', eps=1e-6),
                # patch_resolution=14,
                decoder_type='gpt2',
                freeze_decoder=True,
                additional_layer:int=0,
                ):
        # Note: with freeze_decoder=True the parent constructor will also freeze the
        # LoRA adapters added in load_lm; pass freeze_decoder=False to train them.
        super().__init__(img_size, patch_size, in_chans, embed_dim, decoder_embed_dim, norm_cfg, decoder_type, freeze_decoder, additional_layer)
        
    def load_lm(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.decoder_type)
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.LMconfig = AutoConfig.from_pretrained(self.decoder_type, trust_remote_code=True)
        double_quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
        )
        # 4-bit loading (with a device map) is what makes this QLoRA rather than plain
        # LoRA; prepare_model_for_kbit_training below expects a quantized model.
        model = AutoModelForCausalLM.from_pretrained(self.decoder_type,
                                                     device_map='auto',
                                                     quantization_config=double_quant_config,
                                                     trust_remote_code=True)

        model.gradient_checkpointing_enable()
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=["query_key_value"],  # GPT-NeoX/Falcon-style attention; adjust for other architectures
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM"
        )

        self.lm = get_peft_model(model, lora_config)
        self.lm.print_trainable_parameters()
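
    # Hedged helper: the LoRA target module names depend on the architecture of the
    # loaded checkpoint ("query_key_value" above matches GPT-NeoX/Falcon-style
    # attention). The mapping below is an assumption; verify the names against
    # model.named_modules() before relying on it.
    @staticmethod
    def _default_lora_targets(model_type: str):
        targets = {
            'gpt2': ['c_attn'],
            'gpt_neox': ['query_key_value'],
            'falcon': ['query_key_value'],
            'llama': ['q_proj', 'k_proj', 'v_proj', 'o_proj'],
        }
        return targets.get(model_type, ['query_key_value'])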