""" PyTorch xTrimoPGLM model. """

import math
import copy
import warnings
import re
import sys
import os
import pathlib
import time
import random
import numpy as np
from tqdm.auto import tqdm

import torch
import deepspeed
import torch.utils.checkpoint
import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm, MSELoss, BCEWithLogitsLoss
from torch.nn.utils import skip_init
from typing import Optional, Tuple, Union, List, Callable, Dict, Any
from copy import deepcopy
from collections import namedtuple

from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    MaskedLMOutput,
    CausalLMOutputWithPast,
    SequenceClassifierOutput,
    TokenClassifierOutput
)
from transformers import PreTrainedModel
from transformers.utils import logging
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

from .configuration_xtrimopglm import xTrimoPGLMConfig
from .quantization import quantize

def get_checkpoint_fn():
    if deepspeed.checkpointing.is_configured():
        checkpoint = deepspeed.checkpointing.checkpoint
    else:
        checkpoint = torch.utils.checkpoint.checkpoint
    return checkpoint

# flags required to enable jit fusion kernels

if sys.platform != 'darwin':
    torch._C._jit_set_profiling_mode(False)
    torch._C._jit_set_profiling_executor(False)
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "BioMap/xtrimopglm-100b-int4"
_CONFIG_FOR_DOC = "xTrimoPGLMConfig"
DeepNormCoefficients = namedtuple("DeepNormCoefficients", ["alpha", "beta"])

def default_init(cls, *args, **kwargs):
    return cls(*args, **kwargs)


def get_deepnorm_coefficients(config: xTrimoPGLMConfig):
    """DeepNorm coefficients, following https://kexue.fm/archives/8978."""
    num_layers = config.num_layers
    return DeepNormCoefficients(alpha=(2 * num_layers) ** 0.5, beta=(2 * num_layers) ** -0.5)
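
# Illustrative sketch: for a hypothetical config with num_layers=36, this gives
#   >>> get_deepnorm_coefficients(config)
#   DeepNormCoefficients(alpha=8.4853..., beta=0.1179...)
# i.e. alpha = sqrt(2 * 36) and beta = 1 / sqrt(2 * 36).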


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4  # route all probability mass to token id 5 when scores are invalid
        return scores
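
# Usage sketch (assumes token id 5 is a safe fallback token in this vocabulary):
# appending the processor to the generation pipeline's LogitsProcessorList makes
# NaN/Inf scores collapse onto one valid token instead of crashing sampling.
#   >>> processors = LogitsProcessorList([InvalidScoreLogitsProcessor()])
#   >>> scores = torch.full((1, 8), float("nan"))
#   >>> processors(torch.tensor([[0]]), scores)[0, 5].item()
#   50000.0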


def split_tensor_along_last_dim(
        tensor: torch.Tensor,
        num_partitions: int,
        contiguous_split_chunks: bool = False,
) -> Tuple[torch.Tensor, ...]:
    """Split a tensor along its last dimension.

    Arguments:
        tensor: input tensor.
        num_partitions: number of partitions to split the tensor
        contiguous_split_chunks: If True, make each chunk contiguous
                                 in memory.

    Returns:
        A tuple of tensors.
    """
    # Get the size and dimension.
    last_dim = tensor.dim() - 1
    last_dim_size = tensor.size()[last_dim] // num_partitions
    # Split.
    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
    # Note: torch.split does not create contiguous tensors by default.
    if contiguous_split_chunks:
        return tuple(chunk.contiguous() for chunk in tensor_list)

    return tensor_list
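
# Illustrative sketch: the last dimension must be divisible by num_partitions.
#   >>> x = torch.randn(2, 4, 12)
#   >>> [t.shape for t in split_tensor_along_last_dim(x, 3)]
#   [torch.Size([2, 4, 4]), torch.Size([2, 4, 4]), torch.Size([2, 4, 4])]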

class RotaryEmbedding(torch.nn.Module):
    
    def __init__(self, dim, base=10000, precision=torch.half, learnable=False):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim)).to(precision)
        self.dim = dim
        self.base = base
        self.learnable = learnable
        if learnable:
            self.inv_freq = torch.nn.Parameter(inv_freq)
            self.max_seq_len_cached = None
        else:
            self.register_buffer('inv_freq', inv_freq)
            self.max_seq_len_cached = None
            self.cos_cached = None
            self.sin_cached = None
        self.precision = precision
    
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        if f'{prefix}inv_freq' in state_dict:
            super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
        else:
            self.inv_freq.copy_(1. / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)).to(self.precision))

    def forward(self, x, seq_dim=1, seq_len=None):
        if seq_len is None:
            seq_len = x.shape[seq_dim]
        if self.max_seq_len_cached is None or (seq_len > self.max_seq_len_cached):
            self.max_seq_len_cached = None if self.learnable else seq_len
            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(x.device))
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            if self.precision == torch.bfloat16 or self.precision == torch.half:
                emb = emb.float()
            # [sx, 1 (b * np), hn]
            cos_cached = emb.cos()[:, None, :]
            sin_cached = emb.sin()[:, None, :]
            if self.precision == torch.bfloat16:
                cos_cached = cos_cached.bfloat16()
                sin_cached = sin_cached.bfloat16()
            elif self.precision == torch.half:
                cos_cached = cos_cached.half()
                sin_cached = sin_cached.half()
            if self.learnable:
                return cos_cached, sin_cached
            self.cos_cached, self.sin_cached = cos_cached, sin_cached
        return self.cos_cached[:seq_len, ...], self.sin_cached[:seq_len, ...]

def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in earlier torch versions
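
# Illustrative sketch: rotate_half maps the halves (x1, x2) of the last dimension
# to (-x2, x1), the 90-degree rotation that RoPE combines with cos/sin terms.
#   >>> rotate_half(torch.tensor([1., 2., 3., 4.]))
#   tensor([-3., -4.,  1.,  2.])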

def assert_dim_check(tensor, ndim=None, shape=None):
    if ndim is not None:
        assert tensor.ndim == ndim, f"Expect tensor.ndim={ndim}, but got tensor.shape={tensor.shape}"
    if shape is not None:
        assert list(tensor.shape) == list(shape), f"Expect tensor.shape={shape}, but got tensor.shape={tensor.shape}"

def apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id):  # jitting fails with bf16
    # position_id: [sq, b], q, k: [sq, b, np, hn], cos: [sq, 1, hn] -> [sq, b, 1, hn]
    cos, sin = F.embedding(position_id, cos.squeeze(1)).unsqueeze(2), \
               F.embedding(position_id, sin.squeeze(1)).unsqueeze(2)
    q, k = (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
    return q, k
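
# Shape sketch (sq=seq_len, b=batch, np=num_heads, hn=head_dim), illustrative only:
#   >>> sq, b, nh, hn = 5, 2, 4, 8
#   >>> q = k = torch.randn(sq, b, nh, hn)
#   >>> cos, sin = torch.randn(sq, 1, hn), torch.randn(sq, 1, hn)
#   >>> position_id = torch.zeros(sq, b, dtype=torch.long)
#   >>> q_rot, k_rot = apply_rotary_pos_emb_index_torch(q, k, cos, sin, position_id)
#   >>> q_rot.shape
#   torch.Size([5, 2, 4, 8])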

class RMSNorm(torch.nn.Module):
    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None, **kwargs):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.empty(normalized_shape, device=device, dtype=dtype))
        self.eps = eps

    def forward(self, hidden_states: torch.Tensor):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        return (self.weight * hidden_states).to(input_dtype)
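
# Illustrative sketch: RMSNorm rescales each vector to unit root-mean-square (no
# mean subtraction, unlike LayerNorm) and then applies a learned gain:
#   >>> norm = RMSNorm(2)
#   >>> _ = torch.nn.init.ones_(norm.weight)
#   >>> norm(torch.tensor([[3., 4.]]))
#   tensor([[0.8485, 1.1314]])   # 3/sqrt(12.5), 4/sqrt(12.5)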

class CoreAttention(torch.nn.Module):
    def __init__(self, config: xTrimoPGLMConfig, layer_number):
        super(CoreAttention, self).__init__()

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_partition = projection_size
        self.hidden_size_per_attention_head = projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
        self.coeff = coeff

        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

        self.is_causal = config.is_causal
        self.use_pytorch_sdpa = config.use_pytorch_sdpa
    
    def forward(self, query_layer, key_layer, value_layer, attention_mask):
        # query_layer, key_layer, value_layer: [seq_len, batch_size, num_heads, head_dim]
        pytorch_major_version = int(torch.__version__.split('.')[0])
        if pytorch_major_version >= 2 and self.use_pytorch_sdpa:
            dropout_p = self.attention_dropout.p if self.training else 0
            # [seq_len, batch_size, num_heads, head_dim] -> [batch_size, num_heads, seq_len, head_dim]
            query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
            if attention_mask is None and query_layer.shape[2] == key_layer.shape[2]:
                # context_layer: [batch_size, num_heads, seq_len, head_dim]
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, is_causal=self.is_causal, dropout_p=dropout_p)
            else:
                if (attention_mask is not None) and (attention_mask.dtype == torch.bool):
                    attention_mask = attention_mask.logical_not()  # do NOT negate the mask in place
                context_layer = torch.nn.functional.scaled_dot_product_attention(query_layer, key_layer, value_layer, attention_mask, dropout_p=dropout_p)
            # [batch_size, num_heads, seq_len, head_dim] -> [seq_len, batch_size, num_heads, head_dim]
            context_layer = context_layer.permute(2, 0, 1, 3)
            # [seq_len, batch_size, 2560]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.reshape(*new_context_layer_shape)
        else:
            # Raw attention scores

            # [b, np, sq, sk]
            output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # preallocting input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1], output_size[2], output_size[3], dtype=query_layer.dtype,
                device=query_layer.device
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),  # [b * np, sq, hn]
                key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )

            # change view to [b, np, sq, sk]
            attention_scores = matmul_result.view(*output_size)

            # ===========================
            # Attention probs and dropout
            # ===========================

            # attention scores and attention mask [b, np, sq, sk]
            if self.attention_softmax_in_fp32:
                attention_scores = attention_scores.float()
            if self.coeff is not None:
                attention_scores = attention_scores * self.coeff
            if self.is_causal and attention_mask is None and attention_scores.shape[2] == attention_scores.shape[3]:
                attention_mask = torch.ones(output_size[0], 1, output_size[2], output_size[3],
                                            device=attention_scores.device, dtype=torch.bool)
                attention_mask.tril_()
                attention_mask = ~attention_mask
            if attention_mask is not None:
                attention_scores = attention_scores.masked_fill(attention_mask, float("-inf"))
            attention_probs = F.softmax(attention_scores, dim=-1)
            attention_probs = attention_probs.type_as(value_layer)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.attention_dropout(attention_probs)
            # =========================
            # Context layer. [sq, b, hp]
            # =========================

            # value_layer -> context layer.
            # [sk, b, np, hn] --> [b, np, sq, hn]

            # context layer shape: [b, np, sq, hn]
            output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))
            # change view [sk, b * np, hn]
            value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)
            # change view [b * np, sq, sk]
            attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)
            # matmul: [b * np, sq, hn]
            context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
            # change view [b, np, sq, hn]
            context_layer = context_layer.view(*output_size)
            # [b, np, sq, hn] --> [sq, b, np, hn]
            context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
            # [sq, b, np, hn] --> [sq, b, hp]
            new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
            context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
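
# Shape sketch for CoreAttention.forward (illustrative numbers): with seq_len=128,
# batch=2, and 32 heads of dim 80 (hidden_size_per_partition = 2560),
#   query/key/value: [128, 2, 32, 80]  ->  context_layer: [128, 2, 2560]
# The SDPA path permutes to [batch, heads, seq, dim] only because
# torch.nn.functional.scaled_dot_product_attention expects that layout.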


class SelfAttention(torch.nn.Module):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config: xTrimoPGLMConfig, layer_number, device=None):
        super(SelfAttention, self).__init__()
        self.layer_number = max(1, layer_number)

        self.projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        self.hidden_size_per_attention_head = self.projection_size // config.num_attention_heads
        self.num_attention_heads_per_partition = config.num_attention_heads

        self.multi_query_attention = config.multi_query_attention
        self.qkv_hidden_size = 3 * self.projection_size
        if self.multi_query_attention:
            self.num_multi_query_groups_per_partition = config.multi_query_group_num
            self.qkv_hidden_size = (
                    self.projection_size + 2 * self.hidden_size_per_attention_head * config.multi_query_group_num
            )
        self.query_key_value = nn.Linear(config.hidden_size, self.qkv_hidden_size,
                                         bias=config.add_bias_linear or config.add_qkv_bias,
                                         device=device, **_config_to_kwargs(config)
                                         )

        self.core_attention = CoreAttention(config, self.layer_number)

        # Output.
        self.dense = nn.Linear(self.projection_size, config.hidden_size, bias=config.add_bias_linear, device=device, **_config_to_kwargs(config))
        
        self.rotary_embedding_2d = config.rotary_embedding_2d
        # dim, base=10000, precision=torch.half, learnable=False
        self.rotary_emb = RotaryEmbedding(self.hidden_size_per_attention_head // 2 if self.rotary_embedding_2d else self.hidden_size_per_attention_head, 
                                          base=10000, precision=config.torch_dtype, learnable=False)


    def forward(
            self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        # =====================
        # Query, Key, and Value
        # =====================

        # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(
                query_layer.size()[:-1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
            )
            key_layer = key_layer.view(
                key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
            value_layer = value_layer.view(
                value_layer.size()[:-1]
                + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
            )
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)

        # apply relative positional encoding (rotary embedding)
        if position_ids is not None:  # [batch_size, 2, seq_len] for 2D RoPE, else [batch_size, seq_len]
            if self.rotary_embedding_2d:
                # split the head dim in half; each half carries its own rotary stream
                q1, q2 = query_layer.chunk(2, dim=(query_layer.ndim - 1))
                k1, k2 = key_layer.chunk(2, dim=(key_layer.ndim - 1))
                cos, sin = self.rotary_emb(q1, seq_len=position_ids.max() + 1)
                position_ids, block_position_ids = \
                    position_ids[:, 0, :].transpose(0, 1).contiguous(), \
                    position_ids[:, 1, :].transpose(0, 1).contiguous()
                q1, k1 = apply_rotary_pos_emb_index_torch(q1, k1, cos, sin, position_ids)
                q2, k2 = apply_rotary_pos_emb_index_torch(q2, k2, cos, sin, block_position_ids)
                query_layer = torch.concat([q1, q2], dim=(q1.ndim - 1))
                key_layer = torch.concat([k1, k2], dim=(k1.ndim - 1))
            else:
                # [b, sq] -> [sq, b]
                position_ids = position_ids.transpose(0, 1)
                cos, sin = self.rotary_emb(value_layer, seq_len=position_ids.max() + 1)
                query_layer, key_layer = apply_rotary_pos_emb_index_torch(query_layer, key_layer, cos, sin, position_ids)

        # adjust key and value for inference
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(-1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            value_layer = value_layer.contiguous().view(value_layer.size()[:2] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head))

        # ==================================
        # core attention computation
        # ==================================

        context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) # context_layer: [seq_len, batch_size, num_heads*head_dim]
        output = self.dense(context_layer)
        # =================
        # Output. [sq, b, h]
        # =================

        return output, kv_cache


def _config_to_kwargs(args):
    common_kwargs = {
        "dtype": args.torch_dtype,
    }
    return common_kwargs


class MLP(torch.nn.Module):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config: xTrimoPGLMConfig, device=None):
        super(MLP, self).__init__()

        self.add_bias = config.add_bias_linear
        self.moe = config.moe
        self.num_experts = config.num_experts
        self.experts_per_token = config.experts_per_token  # top-k experts per token (e.g. 2)

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = nn.Linear(
            config.hidden_size,
            config.ffn_hidden_size * 2,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        def swiglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return x[0] * F.silu(x[1])

        def geglu(x):
            x = torch.chunk(x, 2, dim=-1)
            return x[0] * F.gelu(x[1])

        if config.glu_activation == 'geglu':
            self.activation_func = geglu
        elif config.glu_activation == 'swiglu':
            self.activation_func = swiglu
        else:
            raise RuntimeError(f"Unsupported glu_activation: {config.glu_activation}")

        # Project back to h.
        self.dense_4h_to_h = nn.Linear(
            config.ffn_hidden_size,
            config.hidden_size,
            bias=self.add_bias,
            device=device,
            **_config_to_kwargs(config)
        )

        if self.moe:
            assert self.num_experts > 1
            del self.dense_h_to_4h
            del self.dense_4h_to_h
            self.router = nn.Linear(
                config.hidden_size,
                config.num_experts,
                bias=False,
                device=device,
                dtype=torch.float32
            )
            for i in range(0, self.num_experts):
                self.register_module(f"dense_h_to_4h_{i}", nn.Linear(
                    config.hidden_size,
                    config.ffn_hidden_size * 2,
                    bias=self.add_bias,
                    device=device,
                    **_config_to_kwargs(config)
                ))
                self.register_module(f"dense_4h_to_h_{i}", nn.Linear(
                    config.ffn_hidden_size,
                    config.hidden_size,
                    bias=self.add_bias,
                    device=device,
                    **_config_to_kwargs(config)
                ))

    def moe_forward(self, hidden_states, expert_idx):
        intermediate_parallel = getattr(self, f"dense_h_to_4h_{expert_idx}")(hidden_states) 
        intermediate_parallel = self.activation_func(intermediate_parallel) 
        output = getattr(self, f"dense_4h_to_h_{expert_idx}")(intermediate_parallel) 
        return output

    def forward(self, hidden_states):
        if self.moe:
            s, b, n = hidden_states.shape
            dtype = hidden_states.dtype
            hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b, n]
            route = self.router(hidden_states).to(dtype)

            weights, selected_experts = torch.topk(route, self.experts_per_token)
            weights = F.softmax(weights, dim=1, dtype=torch.float).to(hidden_states.dtype)
            output = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
            for expert_idx in range(self.num_experts):
                batch_idx, nth_expert = torch.where(selected_experts == expert_idx)
                if nth_expert.shape[0] == 0:
                    continue
                cur_out = self.moe_forward(hidden_states[batch_idx], expert_idx)
                output[batch_idx] += weights[batch_idx, nth_expert, None] * cur_out
            output = output.reshape(s, b, n)
        else:
            # [s, b, 4hp]
            intermediate_parallel = self.dense_h_to_4h(hidden_states)
            intermediate_parallel = self.activation_func(intermediate_parallel)
            # [s, b, h]
            output = self.dense_4h_to_h(intermediate_parallel)
        return output
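
# Routing sketch for the MoE branch (illustrative, assuming num_experts=8 and
# experts_per_token=2): per token, the router logits are cut down to the top-2
# experts, those two scores are softmax-normalized, and the token's output is the
# weighted sum of the two selected expert MLPs:
#   route:            [s*b, 8]
#   selected_experts: [s*b, 2]   (expert indices per token)
#   weights:          [s*b, 2]   (non-negative, sum to 1 per token)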

class xTrimoPGLMBlock(torch.nn.Module):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size.
    """

    def __init__(self, config: xTrimoPGLMConfig, layer_number, device=None):
        super(xTrimoPGLMBlock, self).__init__()
        self.layer_number = layer_number

        self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm

        self.fp32_residual_connection = config.fp32_residual_connection

        LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
        # Layernorm on the input data.
        self.input_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)

        # Self attention.
        self.self_attention = SelfAttention(config, layer_number, device=device)
        self.hidden_dropout = config.hidden_dropout

        # Layernorm on the attention output
        self.post_attention_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)

        # MLP
        self.mlp = MLP(config, device=device)

        self.deepnorm_coeff = get_deepnorm_coefficients(config) if config.deepnorm else None

    def forward(
            self, hidden_states, attention_mask, position_ids, kv_cache=None, use_cache=True,
    ):
        # hidden_states: [s, b, h]
        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)
        # Self attention.
        attention_output, kv_cache = self.self_attention(
            layernorm_output,
            attention_mask,
            position_ids,  # [batch_size, 2, seq_len]
            kv_cache=kv_cache,
            use_cache=use_cache
        )

        # Residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        layernorm_input = torch.nn.functional.dropout(attention_output, p=self.hidden_dropout, training=self.training)
        if self.deepnorm_coeff is not None:
            layernorm_input = residual * self.deepnorm_coeff.alpha + layernorm_input
        else:
            layernorm_input = residual + layernorm_input

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)

        # Second residual connection.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        output = torch.nn.functional.dropout(mlp_output, p=self.hidden_dropout, training=self.training)
        if self.deepnorm_coeff is not None:
            output = residual * self.deepnorm_coeff.alpha + output
        else:
            #print(f"2 self.deepnorm_coeff is None")
            output = residual + output

        return output, kv_cache


class xTrimoPGLMTransformer(torch.nn.Module):
    """Transformer class."""

    def __init__(self, config: xTrimoPGLMConfig, device=None):
        super(xTrimoPGLMTransformer, self).__init__()

        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_layer_norm = config.post_layer_norm

        # Number of layers.
        self.num_layers = config.num_layers

        # Transformer layers.
        def build_layer(layer_number):
            return xTrimoPGLMBlock(config, layer_number, device=device)

        self.layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layers)])

        if self.post_layer_norm:
            LayerNormFunc = RMSNorm if config.rmsnorm else LayerNorm
            # Final layer norm before output.
            self.final_layernorm = LayerNormFunc(config.hidden_size, eps=config.layernorm_epsilon)

        self.gradient_checkpointing = False

    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def forward(
            self, hidden_states, attention_mask, position_ids, kv_caches=None,
            use_cache: Optional[bool] = True,
            output_hidden_states: Optional[bool] = False,
    ):
        if not kv_caches:
            kv_caches = [None for _ in range(self.num_layers)]
        presents = () if use_cache else None
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_self_attentions = None
        all_hidden_states = () if output_hidden_states else None
        for index in range(self.num_layers):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer = self._get_layer(index)
            if self.gradient_checkpointing and self.training and torch.is_grad_enabled():
                layer_ret = get_checkpoint_fn()(
                    layer,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    kv_caches[index],
                    use_cache
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    attention_mask,
                    position_ids,
                    kv_cache=kv_caches[index],
                    use_cache=use_cache
                )
            hidden_states, kv_cache = layer_ret
            if use_cache:
                presents = presents + (kv_cache,)


        # Final layer norm.
        if self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return hidden_states, presents, all_hidden_states, all_self_attentions


class xTrimoPGLMPreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and
    a simple interface for downloading and loading pretrained models.
    """

    is_parallelizable = False
    supports_gradient_checkpointing = True
    config_class = xTrimoPGLMConfig
    base_model_prefix = "transformer"
    _no_split_modules = ["xTrimoPGLMBlock"]

    _quantized = False


    def get_masks(self, input_ids, past_key_values, padding_mask=None, is_causal=True):
        batch_size, seq_length = input_ids.shape
        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
        if is_causal:
            full_attention_mask.tril_()
        past_length = 0
        if past_key_values:
            past_length = past_key_values[0][0].shape[0]
        if past_length:
            full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
                                                        device=input_ids.device), full_attention_mask), dim=-1)
        if padding_mask is not None:
            full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
        if not past_length and padding_mask is not None:
            full_attention_mask -= padding_mask.unsqueeze(-1) - 1  # un-mask rows of padded queries so no softmax row is fully masked
        full_attention_mask = (full_attention_mask < 0.5).bool()
        full_attention_mask.unsqueeze_(1)
        return full_attention_mask
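
    # Illustrative sketch: for a causal model, seq_length=3, no padding and no cache,
    # the returned mask is True exactly where attention is blocked (upper triangle):
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # broadcast to shape [batch_size, 1, seq_length, seq_length].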

    def get_position_ids(self, input_ids, device, context_length=0):
        batch_size, seq_length = input_ids.shape
        if self.config.rotary_embedding_2d:
            if self.config.is_causal:  # 100b model
                position_ids_1 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)   # [batch_size, seq_len]
                position_ids_2 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)  # [batch_size, seq_len]
                position_ids = torch.stack([position_ids_1, position_ids_2], dim=1)  # [batch_size, 2, seq_len]
            else:
                position_ids_1 = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)  # [batch_size, seq_len]
                position_ids_2 = torch.zeros(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)   # [batch_size, seq_len]
                position_ids = torch.stack([position_ids_1, position_ids_2], dim=1)  # [batch_size, 2, seq_len]
        else:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)  # [batch_size, seq_len]
        return position_ids
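
    # Illustrative sketch: with rotary_embedding_2d=True and is_causal=True (the
    # 100b setup), a hypothetical `model` gives, for batch_size=1 and seq_length=4,
    #   >>> model.get_position_ids(torch.zeros(1, 4, dtype=torch.long), device="cpu")
    #   tensor([[[0, 0, 0, 0],
    #            [0, 1, 2, 3]]])   # shape [batch_size, 2, seq_len] = [1, 2, 4]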

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, xTrimoPGLMTransformer):
            module.gradient_checkpointing = value

    
    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
        """Initialize the weights."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def quantize(self, weight_bit_width: int, empty_init=True, device=None):
        if self._quantized:
            print("Model has already been quantized.")
            return self
        self.transformer.encoder = quantize(self.transformer.encoder, weight_bit_width, empty_init, device)
        self._quantized = True
        return self

class Embedding(torch.nn.Module):
    """Language model embeddings."""

    def __init__(self, config: xTrimoPGLMConfig, device=None):
        super(Embedding, self).__init__()

        self.hidden_size = config.hidden_size
        # Word embeddings (parallel).
        self.word_embeddings = nn.Embedding(
            config.padded_vocab_size,
            self.hidden_size,
            dtype=config.torch_dtype,
            device=device
        )
        self.fp32_residual_connection = config.fp32_residual_connection


    def forward(self, input_ids):
        # Embeddings.
        words_embeddings = self.word_embeddings(input_ids)
        embeddings = words_embeddings
        # Data format change to avoid explicit transposes: [b, s, h] --> [s, b, h].
        embeddings = embeddings.transpose(0, 1).contiguous()
        # If the fp32 residual connection flag is set, convert to float.
        if self.fp32_residual_connection:
            embeddings = embeddings.float()
        return embeddings

class xTrimoPGLMModel(xTrimoPGLMPreTrainedModel):
    def __init__(self, config: xTrimoPGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        self.embedding = init_method(Embedding, config, **init_kwargs)
        self.num_layers = config.num_layers
        self.multi_query_group_num = config.multi_query_group_num
        self.kv_channels = config.kv_channels

        # Rotary positional embeddings
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )

        self.encoder = init_method(xTrimoPGLMTransformer, config, **init_kwargs)
        
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def set_input_embeddings(self, value):
        self.embedding.word_embeddings = value

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None, # position_ids: [batch_size, 2, seq_len]
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        if self.config.is_causal:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_length = input_ids.shape

        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        if full_attention_mask is None:
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
        # Run encoder.
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, position_ids=position_ids,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class xTrimoPGLMForMaskedLM(xTrimoPGLMPreTrainedModel):
    def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        if self.config.quantization_bit:
            print(f"Begin Quantization to {self.config.quantization_bit} bit")
            self.quantize(self.config.quantization_bit, empty_init=True, device=device)

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = None,
            return_last_hidden_state: Optional[bool] = None
    ):
        if self.config.is_causal:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)

        full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal)

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        masked_lm_loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)  # -100 marks padding tokens.
            masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            masked_lm_loss = masked_lm_loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=lm_logits,
            hidden_states=transformer_outputs.last_hidden_state if return_last_hidden_state else transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )




class xTrimoPGLMForSequenceClassification(xTrimoPGLMPreTrainedModel):
    def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        
        self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
        self.classifier = xTrimoPGLMClassificationHead(config)
        if self.config.quantization_bit:
            print(f"Begin Quantization to {self.config.quantization_bit} bit")
            self.quantize(self.config.quantization_bit, empty_init=True, device=device)

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = None,
            return_last_hidden_state: Optional[bool] = None,
            **kwargs
    ) -> Union[Tuple, SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if self.config.is_causal:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)

        full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal=self.config.is_causal)

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if self.config.add_special_tokens:
            hidden_states = transformer_outputs[0][:-1] # get rid of <eos> token
        else:
            hidden_states = transformer_outputs[0]
        logits = self.classifier(hidden_states, add_pooling=True)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

class xTrimoPGLMForTokenClassification(xTrimoPGLMPreTrainedModel):
    def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        
        self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
        if config.task_modality == "token":
            self.classifier = xTrimoPGLMClassificationHead(config)
        elif config.task_modality == 'pair':
            self.classifier = xTrimoPGLMContactHead(config)
        else:
            raise ValueError(f"Unsupported task_modality: {config.task_modality}")


        if self.config.quantization_bit:
            print(f"Begin Quantization to {self.config.quantization_bit} bit")
            self.quantize(self.config.quantization_bit, empty_init=True, device=device)


    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = None,
            return_last_hidden_state: Optional[bool] = None,
            **kwargs
    ) -> Union[Tuple, TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in
            `[0, ..., config.num_labels - 1]`.
        """
        if self.config.is_causal:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)

        full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask, is_causal = self.config.is_causal)

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
            full_attention_mask=full_attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if self.config.add_special_tokens:
            hidden_states = transformer_outputs[0][:-1] # get rid of <eos> token
        else:
            hidden_states = transformer_outputs[0]

        logits = self.classifier(hidden_states, add_pooling=False)
        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + transformer_outputs[2:]
            return ((loss,) + output) if loss is not None else output


        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )



class xTrimoPGLMClassificationHead(nn.Module):
    """Head for classification tasks."""
    def __init__(self, config):
        super().__init__()
        self.activation_func = config.activation_func
        self.layers = torch.nn.ModuleList()
        last_size = config.hidden_size
        for sz in config.inter_hidden_size:
            this_layer = torch.nn.Linear(last_size, sz, bias=config.bias)
            last_size = sz
            self.layers.append(this_layer)
    
    def forward(self, 
                input_features,
                add_pooling: Optional[bool] = True
                ):
        # [s, b, h] -> [b, s ,h]
        input_features = input_features.transpose(0,1).contiguous()
        if add_pooling:
            # [b, h]
            input_features = torch.mean(input_features, dim = 1)
        for i, layer in enumerate(self.layers):
            if i > 0:
                input_features = self.activation_func(input_features)
            input_features = layer(input_features)
        return input_features

class xTrimoPGLMContactHead(nn.Module):
    """Head for sentence-level classification tasks."""
    def __init__(self, config):
        super().__init__()
        self.activation_func = config.activation_func
        self.layers = torch.nn.ModuleList()
        last_size = config.hidden_size * 2
        for sz in config.inter_hidden_size:
            this_layer = torch.nn.Linear(last_size, sz, bias=config.bias)
            last_size = sz
            self.layers.append(this_layer)
    
    def outer_concat(self, x):
        batch_size, seq_len, features = x.shape
        
        # Permute to [batch_size, features, seq_len]
        x = x.permute(0, 2, 1)
        
        # Introduce new dimensions for broadcasting
        x_1 = x[:, None, :, :, None]  # [batch_size, 1, features, seq_len, 1]
        x_2 = x[:, None, :, None, :]  # [batch_size, 1, features, 1, seq_len]
        
        # Repeat along new dimensions
        x_1 = x_1.repeat(1, 1, 1, 1, seq_len)  # [batch_size, 1, features, seq_len, seq_len]
        x_2 = x_2.repeat(1, 1, 1, seq_len, 1)  # [batch_size, 1, features, seq_len, seq_len]
        
        # Concatenate along the second dimension
        x = torch.cat((x_1, x_2), dim=1)  # [batch_size, 2, features, seq_len, seq_len]
        
        # Get lower triangular indices
        I, J = torch.tril_indices(seq_len, seq_len, -1)
        
        # Symmetrize
        x[:, :, :, I, J] = x[:, :, :, J, I]
        
        # Permute to desired shape and make contiguous
        x = x.permute(0, 3, 4, 2, 1).contiguous()  # [batch_size, seq_len, seq_len, features, 2]
        
        # Reshape to combine the last two dimensions
        x = x.view(batch_size, seq_len, seq_len, features * 2)  # [batch_size, seq_len, seq_len, features * 2]
        
        return x
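
    # Shape sketch: outer_concat lifts per-residue features to pairwise features,
    # [batch, seq, h] -> [batch, seq, seq, 2h], symmetrized over the lower triangle
    # so positions (i, j) and (j, i) share one pair representation. Illustrative,
    # with `head` a hypothetical xTrimoPGLMContactHead instance:
    #   >>> head.outer_concat(torch.randn(1, 3, 4)).shape
    #   torch.Size([1, 3, 3, 8])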

    def forward(self, 
                input_features,
                add_pooling: Optional[bool] = True
                ):
        # [s, b, h] -> [b, s ,h]
        input_features = input_features.transpose(0,1).contiguous()
        input_features = self.outer_concat(input_features)
        for i, layer in enumerate(self.layers):
            if i > 0:
                input_features = self.activation_func(input_features)
            input_features = layer(input_features)
        return input_features  





class xTrimoPGLMForCasualLM(xTrimoPGLMPreTrainedModel):
    def __init__(self, config: xTrimoPGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = xTrimoPGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        if self.config.quantization_bit:
            print(f"Begin Quantization to {self.config.quantization_bit} bit")
            self.quantize(self.config.quantization_bit, empty_init=True, device=device)

    def _update_model_kwargs_for_generation(
            self,
            outputs: ModelOutput,
            model_kwargs: Dict[str, Any],
            is_encoder_decoder: bool = False,
            standardize_cache_format: bool = False,
    ) -> Dict[str, Any]:
        # update past_key_values
        model_kwargs["past_key_values"] = self._extract_past_from_model_output(
            outputs, standardize_cache_format=standardize_cache_format
        )

        # update attention mask
        if "attention_mask" in model_kwargs:
            attention_mask = model_kwargs["attention_mask"]
            model_kwargs["attention_mask"] = torch.cat(
                [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
            )

        # update position ids
        if "position_ids" in model_kwargs:
            position_ids = model_kwargs["position_ids"]
            new_position_id = position_ids[..., -1:].clone() # [batch_size, 2, 1]
            if self.config.rotary_embedding_2d:
                new_position_id[:, 1] += 1 # Only update the 2nd dimension
            else:
                new_position_id[:] += 1
            model_kwargs["position_ids"] = torch.cat(
                [position_ids, new_position_id], dim=-1
            ) # [batch_size, 2, seq_len+1]

        model_kwargs["is_first_forward"] = False
        return model_kwargs
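
    # Illustrative sketch of the update above (assuming 2D rotary embeddings,
    # i.e. position_ids of shape [batch_size, 2, seq_len]): the last column is
    # cloned and only its second row is advanced, so if the last column is
    # [p, k] the appended column is [p, k + 1]: the first (global) position
    # stays fixed while the second (in-block) position counts up as tokens are
    # generated.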

    def prepare_inputs_for_generation(
            self,
            input_ids: torch.LongTensor,
            past_key_values: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            is_first_forward: bool = True,
            **kwargs
    ) -> dict:
        # only last token for input_ids if past is not None
        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device) # position_ids: [batch_size, 2, seq_len]
        if not is_first_forward:
            if past_key_values is not None:
                position_ids = position_ids[..., -1:]
                input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "return_last_logit": True,
            "use_cache": use_cache
        }
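
    # Decoding sketch (illustrative): on the first forward the full prompt and its
    # [batch_size, 2, seq_len] position_ids are passed; on later steps
    # (is_first_forward=False with a populated cache) only the last token and the
    # last position column are fed, and the transformer reuses past_key_values.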

    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False
    ):
        if self.config.is_causal:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if position_ids is None:
            position_ids = self.get_position_ids(input_ids, device=input_ids.device)

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids, # position_ids: [batch_size, 2, seq_len]
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_states = transformer_outputs[0]
        if return_last_logit:
            hidden_states = hidden_states[-1:]  # hidden_states is [s, b, h]; keep only the final position
        lm_logits = self.transformer.output_layer(hidden_states)
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
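            # Worked example (illustrative): with labels [t0, t1, t2, t3] and
            # lm_logits of shape [batch, 4, vocab], shift_logits keeps positions
            # 0..2 and shift_labels keeps t1..t3, so the logit at position i is
            # scored against token i + 1; positions labelled -100 are ignored.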

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(
            past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.

        Output shares the same memory storage as `past`.
        """
        return tuple(
            (
                layer_past[0].index_select(1, beam_idx.to(layer_past[0].device)),
                layer_past[1].index_select(1, beam_idx.to(layer_past[1].device)),
            )
            for layer_past in past
        )
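
    # Note (inferred from the index_select above): the cached key/value tensors
    # are laid out with the batch dimension at index 1 (e.g. [seq_len, batch, ...]),
    # so beam re-ordering selects along dim 1 rather than dim 0.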
        
    @torch.inference_mode()
    def chat(self, tokenizer, query: str,  max_length: int = 256, num_beams=1, do_sample=True, 
            top_p=1.0, temperature=1.0, logits_processor=None, **kwargs):
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True,
                                               return_tensors="pt", return_dict=True)
        position_ids = self.get_position_ids(inputs['input_ids'], device=self.device) # TODO: ADD BATCH
        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<eop>")]
        inputs["position_ids"] = position_ids
        inputs = inputs.to(self.device)
        outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
        outputs = outputs.tolist()[0][3:] # 3 for generation prompt "<gmask><sop><eos>"
        if outputs[-1] in eos_token_id:
            outputs = outputs[:-1]
        response = tokenizer.decode(outputs)
        return response
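
    # Minimal usage sketch (illustrative; the checkpoint path and query are
    # placeholders, and it assumes the repo's auto_map routes AutoModelForCausalLM
    # to this class):
    #   from transformers import AutoTokenizer, AutoModelForCausalLM
    #   tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint", trust_remote_code=True)
    #   model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True).eval()
    #   response = model.chat(tokenizer, "MSG", max_length=128, top_p=0.9, temperature=0.8)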

    # TODO: fix bug in streaming chat 
    @torch.inference_mode()
    def stream_chat(self, tokenizer, query: str,  max_length: int = 56, num_beams=1, do_sample=True, 
                    top_p=0.8, temperature=0.8, logits_processor=None, past_key_values = None, **kwargs):
        if logits_processor is None:
            logits_processor = LogitsProcessorList()
        logits_processor.append(InvalidScoreLogitsProcessor())
        eos_token_id = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<eop>")]
        gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
        inputs = tokenizer.apply_chat_template(query, add_generation_prompt=True, tokenize=True,
                                            return_tensors="pt", return_dict=True)
        position_ids = self.get_position_ids(inputs['input_ids'], device=self.device) # TODO: ADD BATCH
        inputs["position_ids"] = position_ids
        inputs = inputs.to(self.device)
        offset = 3 # 3 for generation prompt
        for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
                                            eos_token_id=eos_token_id, return_past_key_values=False,
                                            **gen_kwargs):
            outputs = outputs.tolist()[0][offset:]
            if outputs[-1] in eos_token_id:
                outputs = outputs[:-1]
            # offset = 3 + len(outputs)
            response = tokenizer.decode(outputs)
            if response:
                yield response
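
    # Streaming usage sketch (illustrative; see the TODO above, streaming is
    # flagged as buggy). Each yielded value is the decoded text generated so far:
    #   for partial in model.stream_chat(tokenizer, "MSG", max_length=64):
    #       print(partial, end="\r", flush=True)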

    @torch.inference_mode()
    def stream_generate(
            self,
            input_ids,
            generation_config: Optional[GenerationConfig] = None,
            logits_processor: Optional[LogitsProcessorList] = None,
            stopping_criteria: Optional[StoppingCriteriaList] = None,
            prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
            return_past_key_values=False,
            **kwargs,
    ):
        batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]

        if generation_config is None:
            generation_config = self.generation_config
        generation_config = copy.deepcopy(generation_config)
        model_kwargs = generation_config.update(**kwargs)
        model_kwargs["use_cache"] = generation_config.use_cache
        bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id

        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None

        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        if has_default_max_length and generation_config.max_new_tokens is None:
            warnings.warn(
                f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
                "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
                " recommend using `max_new_tokens` to control the maximum length of the generation.",
                UserWarning,
            )
        elif generation_config.max_new_tokens is not None:
            generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
            if not has_default_max_length:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )

        if input_ids_seq_length >= generation_config.max_length:
            input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
            logger.warning(
                f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_new_tokens`."
            )

        # 2. Set generation parameters if not already defined
        logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()

        logits_processor = self._get_logits_processor(
            generation_config=generation_config,
            input_ids_seq_length=input_ids_seq_length,
            encoder_input_ids=input_ids,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            logits_processor=logits_processor,
        )

        stopping_criteria = self._get_stopping_criteria(
            generation_config=generation_config, stopping_criteria=stopping_criteria
        )
        logits_warper = self._get_logits_warper(generation_config)

        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
        scores = None
        while True:
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=False,
                output_hidden_states=False,
            )

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # sample
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            if generation_config.do_sample:
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(probs, dim=-1)
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            if return_past_key_values:
                yield input_ids, outputs.past_key_values
            else:
                yield input_ids
            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                break
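
        # Generation-loop summary (descriptive of the loop above): each iteration
        # runs one forward pass on the freshly prepared inputs, applies the logits
        # processors and warpers, samples (or argmaxes) one token per sequence,
        # appends it to input_ids, and yields the growing input_ids (optionally
        # with past_key_values) until every sequence has produced an EOS token or
        # a stopping criterion fires.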