import torch
import torch.nn as nn
import torch.nn.functional as F
from models.norm import RMSNorm
from models.rope import precompute_freqs_cis, apply_rotary_emb
import bitsandbytes as bnb
import math


class NormalLinear(nn.Linear):
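    """nn.Linear with parameter initialization disabled (reset_parameters is a
    no-op), so construction is cheap when weights are loaded from a checkpoint."""
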
    def reset_parameters(self) -> None:
        pass


class BnbInt8Linear(bnb.nn.Linear8bitLt):
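    """bitsandbytes Linear8bitLt configured for int8 inference: weights are kept
    in int8 (has_fp16_weights=False) and outlier features above threshold=6.0
    are handled in higher precision (LLM.int8() decomposition)."""
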
    def __init__(self, *args, **kwargs):
        super().__init__(*args, has_fp16_weights=False, threshold=6.0, **kwargs)

    def reset_parameters(self) -> None:
        pass


def get_linear_layer(use_int8):
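    """Return the linear layer class to use: bitsandbytes int8 or plain nn.Linear."""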
    if use_int8:
        return BnbInt8Linear
    return NormalLinear


class WordEmbedding(nn.Module):
    def __init__(self, args):
        super(WordEmbedding, self).__init__()
        self.embedding = nn.Embedding(args.vocab_size, args.emb_size)

    def forward(self, src):
        emb = self.embedding(src)
        return emb


class MultiHeadedAttention(nn.Module):
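    """Multi-head self-attention with rotary position embeddings and a
    key/value cache for incremental decoding."""
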
    def __init__(self, args, hidden_size, heads_num, attention_head_size, has_bias=True, use_int8=True):
        super(MultiHeadedAttention, self).__init__()
        self.heads_num = heads_num

        self.per_head_size = attention_head_size
        self.inner_hidden_size = heads_num * attention_head_size

        Linear = get_linear_layer(use_int8)
        self.linear_layers = nn.ModuleList(
            [Linear(hidden_size, self.inner_hidden_size, bias=has_bias) for _ in range(3)]
        )

        self.final_linear = Linear(self.inner_hidden_size, hidden_size, bias=has_bias)

        # KV cache, shaped (batch_size, seq_length, heads_num, per_head_size):
        # keys/values from earlier decoding steps are kept here so they are not
        # recomputed on every forward pass.
        self.cache_k = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )
        self.cache_v = torch.zeros(
            (args.batch_size, args.seq_length, self.heads_num, self.per_head_size)
        )

    def forward(self, key, value, query, start_pos, continue_exsample, mask, freqs_cis):
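        # `continue_exsample` holds the batch indices (cache rows) that are still
        # being generated; `start_pos` is the cache position where this pass's
        # keys/values are written.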
        batch_size, seq_length, _ = query.size()
        heads_num = self.heads_num
        per_head_size = self.per_head_size
        query, key, value = [l(x).view(batch_size, -1, heads_num, per_head_size) \
                             for l, x in zip(self.linear_layers, (query, key, value))]
        query, key = apply_rotary_emb(query, key, freqs_cis=freqs_cis)
        if self.cache_k.device != key.device:
            self.cache_k = self.cache_k.to(key)
        if self.cache_v.device != value.device:
            self.cache_v = self.cache_v.to(value)

        # Update the cache at [start_pos, start_pos + seq_length), then attend
        # over the full cached prefix.
        self.cache_k[continue_exsample, start_pos: start_pos + seq_length] = key
        self.cache_v[continue_exsample, start_pos: start_pos + seq_length] = value

        key = self.cache_k[continue_exsample, : start_pos + seq_length]
        value = self.cache_v[continue_exsample, : start_pos + seq_length]

        query, key, value = [x.transpose(1, 2) for x in (query, key, value)]

        scores = torch.matmul(query, key.transpose(-2, -1))
        scores = scores / math.sqrt(float(per_head_size))
        if mask is not None:
            scores += mask
        # Softmax in float32 for numerical stability, then cast back to the
        # activation dtype before the value matmul.
        probs = F.softmax(scores.float(), dim=-1).type_as(query)
        output = torch.matmul(probs, value)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_length, -1)
        return self.final_linear(output)


class GatedFeedForward(nn.Module):
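    """Gated (SwiGLU-style) feed-forward network: silu(W_gate x) * (W_1 x) -> W_2."""
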
    def __init__(self, hidden_size, feedforward_size, has_bias=True, use_int8=True):
        super(GatedFeedForward, self).__init__()
        Linear = get_linear_layer(use_int8)
        self.linear_gate = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_1 = Linear(hidden_size, feedforward_size, bias=has_bias)
        self.linear_2 = Linear(feedforward_size, hidden_size, bias=has_bias)
        self.act = F.silu

    def forward(self, x):
        # SwiGLU-style gating: silu(W_gate x) multiplied elementwise with (W_1 x).
        gate = self.act(self.linear_gate(x)).type_as(x)
        inter_linear = self.linear_1(x)
        inter = gate * inter_linear
        output = self.linear_2(inter)
        return output


class TransformerLayer(nn.Module):
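    """Pre-norm transformer block: RMSNorm + multi-head attention and
    RMSNorm + gated feed-forward, each with a residual connection."""
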
    def __init__(self, args):
        super(TransformerLayer, self).__init__()

        if hasattr(args, "attention_head_size"):
            attention_head_size = args.attention_head_size
        else:
            attention_head_size = args.hidden_size // args.heads_num

        has_bias = not args.remove_transformer_bias
        # Multi-head Attention
        self.self_attn = MultiHeadedAttention(
            args, args.hidden_size, args.heads_num, attention_head_size, has_bias=has_bias,
            use_int8=args.use_int8
        )

        # FFN
        self.feed_forward = GatedFeedForward(
            args.hidden_size, args.feedforward_size, has_bias, use_int8=args.use_int8
        )

        self.layer_norm_1 = RMSNorm(args.hidden_size)
        self.layer_norm_2 = RMSNorm(args.hidden_size)

    def forward(self, hidden, start_pos, continue_exsample, mask, freqs_cis=None):
        inter = self.layer_norm_1(hidden)
        inter = self.self_attn(inter, inter, inter, start_pos, continue_exsample, mask, freqs_cis)
        hidden = hidden + inter
        output = self.layer_norm_2(hidden)
        output = self.feed_forward(output) + hidden
        return output


class TransformerEncoder(nn.Module):
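    """Stack of transformer layers with rotary position embeddings, a causal
    mask for multi-token (prefill) steps, and a final RMSNorm."""
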
    def __init__(self, args):
        super(TransformerEncoder, self).__init__()
        self.mask = args.mask
        self.layers_num = args.layers_num

        self.transformer = nn.ModuleList(
            [TransformerLayer(args) for _ in range(self.layers_num)]
        )

        self.layer_norm = RMSNorm(args.hidden_size)
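        # Rotary embedding frequencies are precomputed once for up to twice the
        # maximum sequence length and sliced per forward pass.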
        self.freqs_cis = precompute_freqs_cis(args.hidden_size // args.heads_num, args.max_seq_length * 2)

    def forward(self, emb, start_pos, continue_exsample):
        batch_size, seq_length, _ = emb.size()
        mask = None
        if seq_length > 1:
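            # Causal mask: positions above the diagonal get a large negative
            # bias so each token attends only to itself and earlier tokens.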
            mask = torch.ones(seq_length, seq_length, device=emb.device)
            mask = torch.tril(mask)
            mask = (1.0 - mask) * -10000
            mask = mask.repeat(batch_size, 1, 1, 1)

        hidden = emb
        freqs_cis = self.freqs_cis[start_pos: start_pos + seq_length].to(hidden.device)

        for i in range(self.layers_num):
            hidden = self.transformer[i](hidden, start_pos, continue_exsample, mask, freqs_cis=freqs_cis)
        return self.layer_norm(hidden)


class LmOutput(nn.Module):
    def __init__(self, args):
        super(LmOutput, self).__init__()
        # update: lm output not use int8
        Linear = get_linear_layer(False)
        self.lm = Linear(args.hidden_size, args.vocab_size, bias=False)

    def forward(self, x):
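        # Only the logits of the last position are needed for incremental decoding.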
        return self.lm(x[:, -1, :])


class LLaMa(nn.Module):
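    """LLaMA-style decoder: token embedding -> transformer encoder -> LM head
    returning logits for the last position only."""
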
    def __init__(self, args):
        super(LLaMa, self).__init__()
        self.embedding = WordEmbedding(args)
        self.encoder = TransformerEncoder(args)
        self.target = LmOutput(args)

    #@torch.inference_mode()
    def forward(self, src, start_pos, continue_exsample):
        emb = self.embedding(src)
        output = self.encoder(emb, start_pos, continue_exsample)
        output = self.target(output)
        return output
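

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the model itself. The `args`
    # namespace below is hypothetical: it only provides the fields this file
    # reads, with tiny sizes so shapes can be checked quickly. No checkpoint is
    # loaded, so weights are uninitialized and the logits are meaningless; real
    # inference loads pretrained weights into the model first.
    from argparse import Namespace

    args = Namespace(
        vocab_size=32, emb_size=16, hidden_size=16, heads_num=4,
        feedforward_size=64, layers_num=2, batch_size=2, seq_length=32,
        max_seq_length=32, mask="causal", remove_transformer_bias=True,
        use_int8=False,
    )
    model = LLaMa(args)

    prompt = torch.randint(0, args.vocab_size, (args.batch_size, 8))
    with torch.inference_mode():
        # Prefill: run the whole prompt starting at cache position 0.
        logits = model(prompt, start_pos=0, continue_exsample=[0, 1])
        next_token = logits.argmax(dim=-1, keepdim=True)
        # Decode: feed one new token per example at the next cache position.
        logits = model(next_token, start_pos=prompt.size(1), continue_exsample=[0, 1])
        print(logits.shape)  # torch.Size([2, 32]) -> (batch_size, vocab_size)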