import json

import torch
import transformers
from transformers.cache_utils import *
from transformers.models.llama.modeling_llama import *

from .modules.inf_llm import InfLLMGenerator, inf_llm_forward
from .modules.minference_forward import (
    gather_last_q_vertical_slash_topk_v4,
    gather_last_q_vertical_slash_topk_vllm,
    init_minference_parameters,
    minference_forward,
    minference_kv_cache_cpu_forward,
    minference_vllm_forward,
    minference_with_snapkv_forward,
    search_pattern,
    sum_all_diagonal_matrix,
)
from .ops.streaming_kernel import stream_llm_forward


class RotaryEmbeddingESM(torch.nn.Module):
    """
    Rotary position embeddings based on those in
    [RoFormer](https://huggingface.co/docs/transformers/model_doc/roformer). Query and keys are transformed by rotation
    matrices which depend on their relative positions.

    The cos/sin tables are cached on the instance and are only rebuilt when a
    longer sequence length than the cached one is requested.

    Args:
        dim: Rotary dimension; inverse frequencies are generated for
            ``range(0, dim, 2)``.
        base: Base of the geometric progression of frequencies.
        distance_scale: Factor multiplied into the position indices before
            the rotation angles are computed.
        device: Device on which the inverse-frequency buffer is created.
            ``None`` (default) selects ``"cuda"`` when CUDA is available and
            falls back to ``"cpu"`` otherwise.
    """

    def __init__(
        self,
        dim: int,
        base: Union[int, float] = 10000,
        distance_scale: Union[int, float] = 1,
        device: Union[str, torch.device, None] = None,
    ):
        super().__init__()
        self.base = base
        self.distance_scale = distance_scale

        # Previously the buffer was always created on "cuda", which made the
        # module impossible to construct on CPU-only hosts.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Generate and save the inverse frequency buffer (non trainable)
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        self._seq_len_cached = -1
        self._cos_cached = None
        self._sin_cached = None

    def rotate_half(self, x):
        """Split the last dimension in two halves ``(x1, x2)`` and return ``(-x2, x1)``."""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    @staticmethod
    def _slice_positions(table, rank, start, stop):
        """Slice ``table`` to positions ``[start, stop)`` along its sequence axis.

        The sequence axis is the second-to-last one for the supported ranks
        (2, 3 and 4); tables of any other rank are returned unchanged.
        """
        if rank == 2:
            return table[start:stop, :]
        if rank == 3:
            return table[:, start:stop, :]
        if rank == 4:
            return table[:, :, start:stop, :]
        return table

    def _rebuild_tables(self, seq_len, device, rank):
        """Recompute the cached cos/sin tables for ``seq_len`` positions.

        ``rank`` decides how many leading singleton axes are prepended so that
        the tables broadcast against inputs of that rank. For ranks other than
        2, 3 or 4 only the cached length is updated and the tables are left
        untouched.
        """
        self._seq_len_cached = seq_len
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.outer(t * self.distance_scale, self.inv_freq)
        # Duplicate the angles so they cover both halves used by rotate_half.
        emb = torch.cat((freqs, freqs), dim=-1)
        if rank in (2, 3, 4):
            cos, sin = emb.cos(), emb.sin()
            for _ in range(rank - 2):
                cos = cos.unsqueeze(0)
                sin = sin.unsqueeze(0)
            self._cos_cached = cos
            self._sin_cached = sin

    def apply_rotary_pos_emb(self, x, length, right, cos, sin):
        """Rotate ``x`` using the table rows for positions ``[right - length, right)``.

        The computation is done in float32 and the result is cast back to the
        dtype of ``x``.
        """
        dtype = x.dtype
        rank = cos.dim()
        cos = self._slice_positions(cos, rank, right - length, right)
        sin = self._slice_positions(sin, rank, right - length, right)

        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

    def _update_cos_sin_tables(self, x, seq_dim):
        """Grow the cached tables if ``x`` is longer than the cache along ``seq_dim``.

        Returns:
            The (possibly refreshed) ``(cos, sin)`` cache.
        """
        seq_len = x.size(seq_dim)
        if seq_len > self._seq_len_cached:
            self._rebuild_tables(seq_len, x.device, x.dim())
        return self._cos_cached, self._sin_cached

    def _update_cos_sin_tables_len(self, seq_len, device, dim=None):
        """Grow the cached tables to ``seq_len`` positions if they are shorter.

        When ``dim`` is ``None`` the rank of the existing cache is reused, which
        requires that the cache has been populated before.

        Returns:
            The (possibly refreshed) ``(cos, sin)`` cache.
        """
        if seq_len > self._seq_len_cached:
            if dim is None:
                assert self._cos_cached is not None
                dim = self._cos_cached.dim()
            self._rebuild_tables(seq_len, device, dim)

        return self._cos_cached, self._sin_cached

    def apply_rotary_pos_emb_one_angle(self, x: torch.Tensor, index):
        """Rotate ``x`` with the single table row at position ``index - 1``."""
        dtype = x.dtype
        cos, sin = self._update_cos_sin_tables_len(index, x.device)
        rank = cos.dim()
        cos = self._slice_positions(cos, rank, index - 1, index)
        sin = self._slice_positions(sin, rank, index - 1, index)

        return ((x.float() * cos) + (self.rotate_half(x).float() * sin)).to(dtype)

    def forward(
        self, q: torch.Tensor, k: torch.Tensor, seq_dim=-2
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the rotary embedding to ``q`` and ``k``.

        The keys are taken to span positions ``[0, k_len)``; the queries are
        aligned with the last ``q_len`` of those positions.

        Returns:
            Tuple ``(rotated_q, rotated_k)`` with the dtypes of the inputs.
        """
        self._cos_cached, self._sin_cached = self._update_cos_sin_tables(
            k, seq_dim=seq_dim
        )
        return (
            self.apply_rotary_pos_emb(
                q, q.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
            self.apply_rotary_pos_emb(
                k, k.size(seq_dim), k.size(seq_dim), self._cos_cached, self._sin_cached
            ),
        )


# Registry mapping an attention-type name to the forward implementation
# imported at the top of this module.
# NOTE(review): the identifier is spelled "FORWRAD" (sic); it is left as-is
# because renaming it would require updating every reference to it.
# NOTE(review): `minference_forward` is stored uncalled here, whereas
# `minference_patch` invokes it as a factory (`minference_forward()`) —
# confirm that consumers of this table handle it accordingly.
ATTN_FORWRAD = {
    "streaming": stream_llm_forward,
    "minference": minference_forward,
    "inf_llm": inf_llm_forward,
}


def huggingface_forward(forward):
    """Adapt ``forward`` to the keyword interface of a Hugging Face attention module.

    The returned function calls ``forward`` positionally with the module, the
    hidden states (passed twice), the position ids, the cache flag, the past
    key/value and the module's projection layers and head geometry. Its result
    is repackaged as ``(output, None, past_key_value)``.

    Args:
        forward: Callable returning ``(output, past_key_value)`` when the cache
            flag is true and just ``output`` otherwise.

    Returns:
        A function suitable for binding as an attention module's ``forward``.
    """

    def hf_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        **kwargs,
    ):
        # Attention weights are never produced by the wrapped kernels.
        assert not output_attentions

        positional = (
            self,
            hidden_states,
            hidden_states,
            position_ids,
            use_cache,
            past_key_value,
            self.q_proj,
            self.k_proj,
            self.v_proj,
            self.o_proj,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
        )
        result = forward(*positional)

        # Without caching the kernel returns only the attention output.
        attn_out, new_cache = result if use_cache else (result, None)
        return attn_out, None, new_cache

    return hf_forward


def hf_437_prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Build the keyword arguments for one generation step.

    Presumably mirrors the Transformers 4.37 implementation (per the function
    name — confirm). Tokens already covered by ``past_key_values`` are dropped
    from ``input_ids``, the attention mask is cropped when the cache has a
    maximum length, and ``position_ids`` are derived from the attention mask
    when they are not supplied through ``kwargs``.

    Returns:
        Dict with ``input_ids`` (or ``inputs_embeds`` on the first step),
        ``position_ids``, ``past_key_values``, ``use_cache`` and
        ``attention_mask``.
    """
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # Legacy tuple cache: read the length from the first cached tensor.
            past_length = past_key_values[0][0].shape[2]
            cache_length = past_length
            max_cache_length = None

        # Keep only the tokens that have not been processed yet.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, i.e. part of the input
            # lives only in the cache (e.g. when inputs_embeds was passed).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the cached prefix.
            input_ids = input_ids[:, past_length:]
        # Case 3: otherwise input_ids is assumed to hold only new tokens.

        # Crop the mask when the next step would exceed the cache capacity.
        exceeds_capacity = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if exceeds_capacity:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Derive positions on the fly for batched generation; masked slots get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # inputs_embeds are only honoured on the very first generation step.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs


def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    **kwargs,
):
    """Build the keyword arguments for one generation step, including ``cache_position``.

    When no cache is passed, the first decoder layer's
    ``self_attn.past_key_value`` attribute is consulted; if it exists the cache
    is treated as static and ``past_key_values`` is reported as ``None`` in the
    result. Already-processed tokens are trimmed from ``input_ids``, the
    attention mask is cropped to the cache capacity, and ``position_ids`` are
    derived from the attention mask when not supplied through ``kwargs``.

    Returns:
        Dict with ``input_ids`` (or ``inputs_embeds`` on the first step),
        ``position_ids``, ``cache_position``, ``past_key_values``,
        ``use_cache`` and ``attention_mask``.
    """
    # With static cache, the `past_key_values` is None
    # TODO joao: standardize interface for the different Cache classes and remove of this if
    has_static_cache = False
    if past_key_values is None:
        first_attn = getattr(self.model.layers[0], "self_attn", {})
        past_key_values = getattr(first_attn, "past_key_value", None)
        has_static_cache = past_key_values is not None

    past_length = 0
    if past_key_values is not None:
        if isinstance(past_key_values, transformers.cache_utils.Cache):
            if cache_position is not None:
                past_length = cache_position[0]
            else:
                past_length = past_key_values.get_seq_length()

            if past_key_values.get_max_length() is not None:
                max_cache_length = torch.tensor(
                    past_key_values.get_max_length(), device=input_ids.device
                )
            else:
                max_cache_length = None

            if max_cache_length is None:
                cache_length = past_length
            else:
                cache_length = torch.min(max_cache_length, past_length)
        # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
        else:
            # The legacy shape-based length (past_key_values[0][0].shape[2]) is
            # deliberately not used here; the length comes from cache_position.
            past_length = cache_position[0]
            cache_length = past_length
            max_cache_length = None

        # Keep only the tokens that have not been processed yet.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, i.e. part of the input
            # lives only in the cache (e.g. when inputs_embeds was passed).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the cached prefix.
            input_ids = input_ids[:, past_length:]
        # Case 3: otherwise input_ids is assumed to hold only new tokens.

        # Crop the mask when the next step would exceed the cache capacity.
        exceeds_capacity = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if exceeds_capacity:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Derive positions on the fly for batched generation; masked slots get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # inputs_embeds are only honoured on the very first generation step.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
        # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
        # TODO: use `next_tokens` directly instead.
        model_inputs = {"input_ids": input_ids.contiguous()}

    if position_ids is not None:
        input_length = position_ids.shape[-1]
    else:
        input_length = input_ids.shape[-1]

    if cache_position is None:
        cache_position = torch.arange(
            past_length, past_length + input_length, device=input_ids.device
        )
    else:
        cache_position = cache_position[-input_length:]

    # A static cache lives on the attention modules, so it is not passed along.
    if has_static_cache:
        past_key_values = None

    model_inputs["position_ids"] = position_ids
    model_inputs["cache_position"] = cache_position
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs


def prepare_inputs_for_generation_snapkv(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    """Build the keyword arguments for one generation step when SnapKV is active.

    On the first step (no cache yet) every decoder layer's
    ``self_attn.kv_seq_len`` counter is reset to 0. For legacy (non-``Cache``)
    caches the number of already-processed tokens is read from the first
    layer's ``kv_seq_len`` rather than from the cached tensor's shape.

    Returns:
        Dict with ``input_ids`` (or ``inputs_embeds`` on the first step),
        ``position_ids``, ``past_key_values``, ``use_cache`` and
        ``attention_mask``.
    """
    if past_key_values is None:  # [SnapKV]
        for decoder_layer in self.model.layers:
            decoder_layer.self_attn.kv_seq_len = 0
    else:
        if isinstance(past_key_values, Cache):
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
        else:
            # The tensor shape (past_key_values[0][0].shape[2]) is deliberately
            # not used here; the per-layer counter is the source of truth.
            past_length = self.model.layers[0].self_attn.kv_seq_len
            cache_length = past_length
            max_cache_length = None

        # Keep only the tokens that have not been processed yet.
        if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
            # Case 1: the mask is longer than input_ids, i.e. part of the input
            # lives only in the cache (e.g. when inputs_embeds was passed).
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        elif past_length < input_ids.shape[1]:
            # Case 2: input_ids holds every token; drop the cached prefix.
            input_ids = input_ids[:, past_length:]
        # Case 3: otherwise input_ids is assumed to hold only new tokens.

        # Crop the mask when the next step would exceed the cache capacity.
        exceeds_capacity = (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        )
        if exceeds_capacity:
            attention_mask = attention_mask[:, -max_cache_length:]

    position_ids = kwargs.get("position_ids")
    if position_ids is None and attention_mask is not None:
        # Derive positions on the fly for batched generation; masked slots get 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1] :]

    # inputs_embeds are only honoured on the very first generation step.
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs["position_ids"] = position_ids
    model_inputs["past_key_values"] = past_key_values
    model_inputs["use_cache"] = kwargs.get("use_cache")
    model_inputs["attention_mask"] = attention_mask
    return model_inputs


def _prepare_decoder_attention_mask_inference(
    self, attention_mask, input_shape, inputs_embeds, past_key_values_length
):
    # [bsz, seq_len]
    if past_key_values_length > 0 and attention_mask is not None:
        attention_mask = torch.cat(
            (
                torch.full(
                    (input_shape[0], past_key_values_length),
                    True,
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                ),
                attention_mask,
            ),
            dim=-1,
        )

    if attention_mask is not None and torch.all(attention_mask):
        return None  # This uses the faster call when training with full samples

    return attention_mask


def forward_llama_decoder_layer(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
    """
    Decoder-layer forward pass that runs the layer norms and the MLP in
    sequence chunks of 32000 positions.

    The input layer norm is written back into ``hidden_states`` in place, so
    the caller's tensor is modified; a clone taken beforehand serves as the
    residual.

    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): forwarded unchanged to the attention module.
        position_ids (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        output_attentions (`bool`, *optional*):
            Whether the attention weights returned by the attention module are appended to the output.
        use_cache (`bool`, *optional*):
            Whether the key/value state returned by the attention module is appended to the output.
        padding_mask (`torch.LongTensor`, *optional*): forwarded unchanged to the attention module.

    Returns:
        Tuple starting with the layer output, followed by the attention weights
        (if ``output_attentions``) and the present key/value (if ``use_cache``).
    """
    chunk_len = 32000

    skip = hidden_states.clone()
    _, total_len, _ = hidden_states.shape
    spans = [
        slice(lo, min(total_len, lo + chunk_len))
        for lo in range(0, total_len, chunk_len)
    ]

    # Normalise chunk by chunk, overwriting the input buffer.
    for span in spans:
        hidden_states[:, span, :] = self.input_layernorm(hidden_states[:, span, :])

    # Self Attention
    attn_out, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        padding_mask=padding_mask,
    )
    hidden_states = skip + attn_out

    # Fully Connected: the MLP branch is accumulated onto the residual stream
    # one chunk at a time.
    for span in spans:
        piece = hidden_states[:, span, :].clone()
        piece = self.mlp(self.post_attention_layernorm(piece))
        hidden_states[:, span, :] += piece

    extras = []
    if output_attentions:
        extras.append(self_attn_weights)
    if use_cache:
        extras.append(present_key_value)

    return (hidden_states, *extras)


def forward_llama_model(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement forward pass for the Llama decoder stack.

    Embeds the inputs (unless ``inputs_embeds`` is given), runs every decoder
    layer, and applies the final norm in place in chunks of 32000 positions.

    Flags left as ``None`` (``use_cache``, ``output_attentions``,
    ``output_hidden_states``, ``return_dict``) fall back to ``self.config``.
    When caching is enabled, a non-``Cache`` ``past_key_values`` is converted
    to a ``DynamicCache`` for the layers and converted back on return.

    Returns:
        A ``BaseModelOutputWithPast`` when ``return_dict`` is true; otherwise a
        tuple of the non-``None`` values among the last hidden state, the
        cache, all hidden states and all attentions.

    Raises:
        ValueError: If both or neither of ``input_ids`` and ``inputs_embeds``
            are provided.
    """
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both input_ids and inputs_embeds at the same time"
        )
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape[:2]
    elif inputs_embeds is not None:
        batch_size, seq_length = inputs_embeds.shape[:2]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    seq_length_with_past = seq_length
    past_key_values_length = 0

    if use_cache:
        # Layers always receive a Cache object; remember whether the caller
        # used the legacy tuple format so it can be restored on return.
        use_legacy_cache = not isinstance(past_key_values, Cache)
        if use_legacy_cache:
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
        past_key_values_length = past_key_values.get_usable_length(seq_length)
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        position_ids = torch.arange(
            past_key_values_length,
            seq_length + past_key_values_length,
            dtype=torch.long,
            device=device,
        )
        position_ids = position_ids.unsqueeze(0)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    # NOTE: a `padding_mask` used to be derived here via `0 in attention_mask`,
    # but the decoder layers below are invoked without it, so it was never
    # consumed. The dead computation (and its tensor membership test) is gone.
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past),
            dtype=torch.bool,
            device=inputs_embeds.device,
        )

    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    # embed positions
    hidden_states = inputs_embeds

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = None

    for decoder_layer in self.layers:
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache sits after the attention weights when those are returned.
            next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    # Final norm, applied in place chunk by chunk.
    _, seq_len, _ = hidden_states.shape
    for start_idx in range(0, seq_len, 32000):
        end_idx = min(seq_len, start_idx + 32000)
        hidden_states[:, start_idx:end_idx, :] = self.norm(
            hidden_states[:, start_idx:end_idx, :]
        )

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = None
    if use_cache:
        next_cache = (
            next_decoder_cache.to_legacy_cache()
            if use_legacy_cache
            else next_decoder_cache
        )
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )


def forward_llama_for_causal_lm(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    """Replacement forward pass for the causal-LM head.

    With ``labels`` the next-token cross-entropy loss is accumulated over
    chunks of 32000 positions and divided by the number of non-negative
    shifted labels; no logits are returned in that case. Without ``labels``
    only the logits of the last position are computed, unless the config
    carries a truthy ``is_ppl`` entry, in which case logits for every position
    are produced.

    Returns:
        A ``CausalLMOutputWithPast`` holding ``loss``, ``logits`` and
        ``past_key_values`` when ``return_dict`` is true; otherwise the tuple
        form of that output.
    """
    # assert labels is not None
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    # The decoder is always asked for a ModelOutput: `past_key_values` is read
    # by attribute below, which used to fail with AttributeError on the tuple
    # returned when `return_dict` was false.
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=True,
    )
    torch.cuda.empty_cache()

    hidden_states = outputs[0]
    if labels is not None:
        loss_fct = CrossEntropyLoss(reduction="sum")
        # Taken from the hidden states rather than `input_ids`, so that labels
        # can also be combined with `inputs_embeds` (where input_ids is None).
        valid_seq_len = hidden_states.shape[-2] - 1
        valid_seq_len_slide_win = torch.sum(labels[:, 1:] >= 0).item()
        loss = 0.0

        # Project to the vocabulary chunk by chunk; only one chunk of logits
        # exists at a time.
        for start_idx in range(0, valid_seq_len, 32000):
            end_idx = min(start_idx + 32000, valid_seq_len)
            shift_logits = self.lm_head(
                hidden_states[..., start_idx:end_idx, :]
            ).float()
            shift_labels = labels[..., start_idx + 1 : end_idx + 1].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss += loss_fct(shift_logits, shift_labels)

        loss /= valid_seq_len_slide_win
        logits = None
    else:
        if self.config.to_dict().get("is_ppl", False):
            logits = self.lm_head(hidden_states)
        else:
            logits = self.lm_head(hidden_states[:, -1:]).float()
        loss = None

    result = CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
    )
    # Honour the tuple return form advertised by the signature.
    return result if return_dict else result.to_tuple()


def minference_patch(model, config):
    """Patch ``model`` in place so that it runs with MInference attention.

    Dispatches to ``minference_patch_kv_cache_cpu`` when
    ``config.kv_cache_cpu`` is set and to ``minference_patch_with_snapkv`` when
    ``config.use_snapkv`` is set (checked in that order). Otherwise:

    * every attention module gets the MInference helpers and forward bound to it,
    * every decoder layer gets ``forward_llama_decoder_layer``,
    * the model's ``prepare_inputs_for_generation``, the decoder stack's
      ``forward`` / ``_prepare_decoder_attention_mask`` and the top-level
      ``forward`` are replaced, and SDPA is disabled via ``_use_sdpa``.

    Args:
        model: Model exposing ``model.model.layers``; the attention and decoder
            layer classes are taken from the first layer.
        config: Object with ``kv_cache_cpu`` and ``use_snapkv`` attributes.

    Returns:
        The same ``model`` object. On this (default) path it is additionally
        flagged with ``has_patch = True``.
    """
    if config.kv_cache_cpu:
        return minference_patch_kv_cache_cpu(model)
    if config.use_snapkv:
        return minference_patch_with_snapkv(model)

    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_forward()

    def update_module(m):
        # Bind the replacement functions as methods of each matching submodule.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)
    model.has_patch = True

    print("Patched model for minference..")
    return model


def minference_patch_kv_cache_cpu(model):
    """Patch ``model`` in place for MInference using the KV-cache-on-CPU forward.

    Besides rebinding the attention, decoder-layer and model-level methods,
    this replaces ``DynamicCache.update`` and ``DynamicCache.get`` on the
    ``transformers`` class itself, so the change applies to every
    ``DynamicCache`` in the process, not only to this model.

    Args:
        model: Model exposing ``model.model.layers``; the attention and decoder
            layer classes are taken from the first layer.

    Returns:
        The same ``model`` object.
    """
    # NOTE(review): `cpu_cache_update` / `cpu_cache_get` are not imported at the
    # top of this file; presumably they are defined elsewhere in this module —
    # confirm.
    transformers.cache_utils.DynamicCache.update = cpu_cache_update
    transformers.cache_utils.DynamicCache.get = cpu_cache_get

    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_kv_cache_cpu_forward()

    def update_module(m):
        # Bind the replacement functions as methods of each matching submodule.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = hf_437_prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    print("Patched model for MInference load KV Cache to CPU.")
    return model


def minference_patch_with_snapkv(model):
    """Patch ``model`` in place for MInference combined with SnapKV.

    Attention modules receive the MInference helpers and the SnapKV-aware
    forward; decoder layers, the decoder stack and the top-level model get the
    replacement forwards defined in this module. Input preparation is switched
    to ``prepare_inputs_for_generation_snapkv``.

    Args:
        model: Model exposing ``model.model.layers``; the attention and decoder
            layer classes are taken from the first layer.

    Returns:
        The same ``model`` object.
    """
    Attention = model.model.layers[0].self_attn.__class__
    DecoderLayer = model.model.layers[0].__class__

    forward = minference_with_snapkv_forward()

    def update_module(m):
        # Bind the replacement functions as methods of each matching submodule.
        if isinstance(m, Attention):
            m.init_minference_parameters = init_minference_parameters.__get__(
                m, Attention
            )
            m.gather_last_q_vertical_slash_topk_v4 = (
                gather_last_q_vertical_slash_topk_v4.__get__(m, Attention)
            )
            m.forward = forward.__get__(m, Attention)
        if isinstance(m, DecoderLayer):
            m.forward = forward_llama_decoder_layer.__get__(m, DecoderLayer)

    model.apply(update_module)
    model.prepare_inputs_for_generation = prepare_inputs_for_generation_snapkv.__get__(
        model, model.__class__
    )
    model.model._use_sdpa = False

    model.model._prepare_decoder_attention_mask = (
        _prepare_decoder_attention_mask_inference.__get__(
            model.model, model.model.__class__
        )
    )
    model.model.forward = forward_llama_model.__get__(
        model.model, model.model.__class__
    )
    model.forward = forward_llama_for_causal_lm.__get__(model, model.__class__)

    # Message typo fixed: "SanpKV" -> "SnapKV".
    print("Patched model for minference with SnapKV..")
    return model


def llama_model_forward_vllm(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata,
    inputs_embeds: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run every decoder layer, passing each one its own index.

    Args:
        input_ids: Token ids; only used when ``inputs_embeds`` is ``None``.
        positions: Forwarded unchanged to every layer.
        kv_caches: Per-layer caches; entry ``i`` is handed to layer ``i``.
        attn_metadata: Forwarded unchanged to every layer.
        inputs_embeds: Optional precomputed embeddings that bypass the
            embedding lookup.

    Returns:
        The first element of ``self.norm(hidden_states, residual)`` after the
        last layer.
    """
    hidden_states = (
        self.get_input_embeddings(input_ids)
        if inputs_embeds is None
        else inputs_embeds
    )

    residual = None
    for layer_idx, layer in enumerate(self.layers):
        # layer_idx is the extra keyword the patched layer forward expects.
        hidden_states, residual = layer(
            positions,
            hidden_states,
            kv_caches[layer_idx],
            attn_metadata,
            residual,
            layer_idx=layer_idx,
        )

    normed, _ = self.norm(hidden_states, residual)
    return normed


def llama_layer_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    residual: Optional[torch.Tensor],
    layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Decoder-layer forward that passes ``layer_idx`` on to self-attention.

    Args:
        positions: Forwarded to ``self.self_attn``.
        hidden_states: Input activations of this layer.
        kv_cache: This layer's cache, forwarded to ``self.self_attn``.
        attn_metadata: Forwarded to ``self.self_attn``.
        residual: Running residual from the previous layer, or ``None`` when
            no residual has been produced yet.
        layer_idx: Index of this layer, forwarded to ``self.self_attn``.

    Returns:
        ``(hidden_states, residual)`` after attention and the MLP.
    """
    # Input norm: with a residual, the norm is called with both tensors and
    # returns the updated pair; without one, the raw input becomes the residual.
    if residual is not None:
        hidden_states, residual = self.input_layernorm(hidden_states, residual)
    else:
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

    # Self-attention block.
    attn_out = self.self_attn(
        positions=positions,
        hidden_states=hidden_states,
        kv_cache=kv_cache,
        attn_metadata=attn_metadata,
        layer_idx=layer_idx,
    )

    # Feed-forward block.
    normed, residual = self.post_attention_layernorm(attn_out, residual)
    return self.mlp(normed), residual


def llama_attn_forward_vllm(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata,
    layer_idx: int,
) -> torch.Tensor:
    """Attention forward that additionally hands ``layer_idx`` to ``self.attn``.

    Projects ``hidden_states`` through ``self.qkv_proj``, splits the result
    into query/key/value along the last dimension, applies ``self.rotary_emb``
    to query and key, runs ``self.attn`` and returns the ``self.o_proj`` output.
    """
    fused_qkv, _ = self.qkv_proj(hidden_states)
    split_sizes = [self.q_size, self.kv_size, self.kv_size]
    query, key, value = fused_qkv.split(split_sizes, dim=-1)
    query, key = self.rotary_emb(positions, query, key)
    context = self.attn(
        query, key, value, kv_cache, attn_metadata, self.kv_scale, layer_idx
    )
    projected, _ = self.o_proj(context)
    return projected


def vllm_attn_forward(
    self,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: Optional[torch.Tensor],
    attn_metadata,
    kv_scale: float = 1.0,
    layer_idx: int = 0,
) -> torch.Tensor:
    """Delegate to ``self.impl.forward``, adding ``layer_idx`` as the last argument.

    All arguments are forwarded positionally and in order; the return value of
    ``self.impl.forward`` is passed back unchanged.
    """
    impl_args = (query, key, value, kv_cache, attn_metadata, kv_scale, layer_idx)
    return self.impl.forward(*impl_args)


def minference_patch_vllm(
    llm,
    config_file,
):
    """Patch the model held by a vLLM engine so attention runs through MInference.

    Args:
        llm: Object exposing
            ``llm_engine.model_executor.driver_worker.model_runner.model``.
        config_file: Path to a JSON file; its parsed content is passed to
            ``minference_vllm_forward``.

    Returns:
        The same ``llm`` object, with its model patched in place.
    """
    from vllm.attention import Attention
    from vllm.model_executor.models.llama import (
        LlamaAttention,
        LlamaDecoderLayer,
        LlamaModel,
    )

    # Use a context manager so the config file handle is closed deterministically.
    with open(config_file, encoding="utf-8") as config_fp:
        config = json.load(config_fp)
    attn_forward = minference_vllm_forward(config)

    def update_module(m):
        if isinstance(m, Attention):
            # Outer attention wrapper: forward layer_idx down to the backend impl.
            m.forward = vllm_attn_forward.__get__(m, Attention)

            # Backend implementation object: bind the MInference kernel entry
            # point and replace its forward.
            impl = m.impl
            impl_cls = impl.__class__
            impl.gather_last_q_vertical_slash_topk_vllm = (
                gather_last_q_vertical_slash_topk_vllm.__get__(impl, impl_cls)
            )
            impl.forward = attn_forward.__get__(impl, impl_cls)
        if isinstance(m, LlamaDecoderLayer):
            m.forward = llama_layer_forward_vllm.__get__(m, LlamaDecoderLayer)
        if isinstance(m, LlamaModel):
            m.forward = llama_model_forward_vllm.__get__(m, LlamaModel)
        if isinstance(m, LlamaAttention):
            m.forward = llama_attn_forward_vllm.__get__(m, LlamaAttention)

    llm.llm_engine.model_executor.driver_worker.model_runner.model.apply(update_module)

    print("Patched model for minference with VLLM..")
    return llm


def patch_hf(
    model,
    attn_type: str = "inf_llm",
    attn_kwargs: Optional[dict] = None,
    base=None,
    distance_scale=None,
    **kwargs,
):
    """Patch a HuggingFace causal LM in place with a custom attention forward.

    Args:
        model: A Llama, Mistral or Qwen2 ``*ForCausalLM`` instance, or a model
            whose class is named ``MiniCPMForCausalLM`` / ``Phi3ForCausalLM``.
        attn_type: Key into ``ATTN_FORWRAD`` selecting the attention factory.
        attn_kwargs: Keyword arguments for that factory. The dict is copied,
            never mutated.
        base: Rotary base; defaults to the base of the model's own rotary
            embedding.
        distance_scale: Rotary distance scale; defaults to ``1.0``.
        **kwargs: Extra factory arguments; they override ``attn_kwargs``.

    Returns:
        The patched model, wrapped in ``InfLLMGenerator`` when
        ``attn_type == "inf_llm"``.

    Raises:
        ValueError: If ``model`` is not one of the supported model classes.
    """
    # Merge into a fresh dict: the previous `attn_kwargs={}` default was mutated
    # by `.update(kwargs)`, leaking options across calls and into caller dicts.
    attn_kwargs = {**(attn_kwargs or {}), **kwargs}
    # This approach lacks scalability and will be refactored.
    from transformers import LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM
    from transformers.models.llama.modeling_llama import BaseModelOutputWithPast

    def model_forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        *args,
        **kwargs,
    ):
        """Replacement base-model forward that feeds ``self.position_bias`` to every layer."""
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        # (the shape unpacking below also validates the rank of the input)
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
            # Some configs carry an embedding scale factor; apply it when present.
            if hasattr(self, "config") and hasattr(self.config, "scale_emb"):
                inputs_embeds = inputs_embeds * self.config.scale_emb

        # Per-layer caches are collected into a tuple when caching is enabled.
        pkv = tuple() if use_cache else None

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for i, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                # The rotary module is passed in the `position_ids` slot; the
                # incoming `position_ids` argument is not used.
                position_ids=self.position_bias,
                past_key_value=(
                    past_key_values[i] if past_key_values is not None else None
                ),
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

            hidden_states = layer_outputs[0]

            if use_cache:
                _cache = layer_outputs[2 if output_attentions else 1]
                pkv = pkv + (_cache,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        # Final norm, written back in place in slices of 32000 positions along
        # dim 1 (presumably to bound peak memory on long sequences).
        norm_chunk = 32000
        total_len = hidden_states.size(1)
        for start_idx in range(0, total_len, norm_chunk):
            end_idx = min(total_len, start_idx + norm_chunk)
            hidden_states[:, start_idx:end_idx, :] = self.norm(
                hidden_states[:, start_idx:end_idx, :]
            )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, pkv, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=pkv,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    forward = huggingface_forward(ATTN_FORWRAD[attn_type](**attn_kwargs))

    # All supported families expose the same `model.model.layers[i].self_attn`
    # layout, so one check replaces the per-family branches.
    supported_by_class = (LlamaForCausalLM, MistralForCausalLM, Qwen2ForCausalLM)
    supported_by_name = ("MiniCPMForCausalLM", "Phi3ForCausalLM")
    if not (
        isinstance(model, supported_by_class)
        or model.__class__.__name__ in supported_by_name
    ):
        raise ValueError(
            "Only supports llama, mistral, qwen2, MiniCPM and Phi3 models."
        )
    Attention = model.model.layers[0].self_attn.__class__
    Model = model.model.__class__

    # Build the replacement rotary embedding from the model's own parameters
    # and keep the original one reachable.
    hf_rope = model.model.layers[0].self_attn.rotary_emb
    base = base if base is not None else hf_rope.base
    distance_scale = distance_scale if distance_scale is not None else 1.0
    rope = RotaryEmbeddingESM(hf_rope.dim, base, distance_scale)
    model.model.position_bias = rope
    model.model.hf_position_bias = hf_rope

    def set_forward(m):
        if isinstance(m, Attention):
            # Keep the original forward so the patch can be inspected/undone.
            m._old_forward = m.forward
            m.forward = forward.__get__(m, Attention)

    model.apply(set_forward)

    model._old_prepare_inputs_for_generation = model.prepare_inputs_for_generation
    model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(
        model, model.__class__
    )
    model.model._old_forward = model.model.forward
    model.model.forward = model_forward.__get__(model.model, Model)

    if attn_type == "inf_llm":
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model.config._name_or_path
        )
        model = InfLLMGenerator(model, tokenizer)

    print("Patched model ...")
    return model


def fp8_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

    Return:
        A tuple containing the updated key and value states.
    """
    # Update the number of seen tokens
    if layer_idx == 0:
        self.seen_tokens += key_states.shape[-2]

    # Update the cache
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(key_states.to(torch.float8_e5m2))
        self.value_cache.append(value_states.to(torch.float8_e5m2))
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], key_states.to(torch.float8_e5m2)], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], value_states.to(torch.float8_e5m2)], dim=-2
        )

    return self.key_cache[layer_idx].to(key_states.dtype), self.value_cache[
        layer_idx
    ].to(key_states.dtype)


def cpu_cache_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
    """
    Stores the new `key_states` and `value_states` for layer `layer_idx` in the CPU-resident cache.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Unused here.

    Return:
        `None`. The cache is updated in place; use `cpu_cache_get` to read it back.
    """
    # Count newly seen tokens once per step (on layer 0), on whichever counter
    # attribute this cache instance carries.
    if layer_idx == 0:
        new_tokens = key_states.shape[-2]
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += new_tokens
        else:
            self.seen_tokens += new_tokens

    cpu_keys = key_states.cpu()
    cpu_values = value_states.cpu()

    # Update the cache: first visit of a layer appends, later visits extend
    # along the second-to-last dimension.
    if len(self.key_cache) <= layer_idx:
        self.key_cache.append(cpu_keys)
        self.value_cache.append(cpu_values)
    else:
        self.key_cache[layer_idx] = torch.cat(
            [self.key_cache[layer_idx], cpu_keys], dim=-2
        )
        self.value_cache[layer_idx] = torch.cat(
            [self.value_cache[layer_idx], cpu_values], dim=-2
        )


def cpu_cache_get(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    head_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the key/value states of one head, prefixed with that head's cached history.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states for the requested head.
        value_states (`torch.Tensor`):
            The new value states for the requested head.
        layer_idx (`int`):
            The index of the layer to read from.
        head_idx (`int`):
            Index along dim 1 of the cached tensors selecting the slice to fetch.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Unused here.

    Return:
        `(key_states, value_states)` with the cached slice concatenated in front along dim -2,
        or the inputs unchanged when nothing is cached for `layer_idx` yet.
    """
    # The seen-token counter is bumped on every layer-0 call, on whichever
    # counter attribute this cache instance carries.
    if layer_idx == 0:
        added = key_states.shape[-2]
        if "_seen_tokens" in self.__dict__:
            self._seen_tokens += added
        else:
            self.seen_tokens += added

    # Nothing stored for this layer yet: hand the inputs straight back.
    if len(self.key_cache) <= layer_idx:
        return key_states, value_states

    # Move only the requested head's slice to the GPU, then prepend it.
    head = slice(head_idx, head_idx + 1)
    cached_keys = self.key_cache[layer_idx][:, head].cuda()
    full_keys = torch.cat([cached_keys, key_states], dim=-2)
    cached_values = self.value_cache[layer_idx][:, head].cuda()
    full_values = torch.cat([cached_values, value_states], dim=-2)
    return full_keys, full_values