import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from typing import Optional, Tuple
import math
import logging

logger = logging.getLogger(__name__)

rwkv_emb_scale = 0.4    # try 0.4 for char-level english. try 1.0 for chinese.
rwkv_layer_decay = 1.0  # decay weights in higher layers. try 0.5 ~ 1.0.


class AttentionConfig:
    def __init__(self, ctx_len=100, **kwargs):
        self.ctx_len = ctx_len
        for k, v in kwargs.items():
            setattr(self, k, v)

########################################################################################################
# MHA_rotary: Multi-head Attention + Rotary Encoding + GeGLU FFN
########################################################################################################

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_len=None):
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)  # match inv_freq dtype for the einsum below
            freqs = torch.einsum('i,j->ij', t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()
            self.sin_cached = emb.sin()
        return torch.stack([self.cos_cached, self.sin_cached])


class ContinuousRotaryEmbedding(torch.nn.Module):
    '''Continuous rotary position embedding'''
    def __init__(self, dim, sequence_scale):
        super().__init__()
        base = 10000
        self.sequence_scale = sequence_scale
        self.register_buffer('inv_freq', 1. / (base ** (torch.arange(0, dim, 2))))

    def forward(self, t):
        t = (t + 0.5) * self.sequence_scale
        freqs = torch.einsum('ij,k->ijk', t, self.inv_freq)   # freqs: [B, L, dim//2]
        emb = torch.cat((freqs, freqs), dim=-1).unsqueeze(1)  # emb: [B, 1, L, dim], 1 for broadcast in head_num dim
        return torch.stack([emb.cos(), emb.sin()])


def rotate_half(x):
    x1, x2 = x[..., :x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), -1)


@torch.jit.script
def apply_rotary_pos_emb(q, k, cos, sin):
    cos, sin = cos[..., :q.shape[2], :], sin[..., :q.shape[2], :]
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class MHA_rotary(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.encoder_dim % args.num_heads == 0
        self.num_heads = args.num_heads
        self.head_size = args.encoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.key = nn.Linear(args.encoder_dim, args.encoder_dim)
        self.value = nn.Linear(args.encoder_dim, args.encoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.encoder_dim, args.encoder_dim)

    def forward(self, x, RoPE, key_padding_mask=None):
        B, T, C = x.size()

        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]        # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v                                                      # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)

        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x


class MHA_decoder(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.collect_attention_map = False
        self.attention_map = None
        assert args.decoder_dim % args.num_heads == 0  # head size is derived from decoder_dim below
        self.num_heads = args.num_heads
        self.head_size = args.decoder_dim // args.num_heads

        if args.timeshift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        self.query = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.key = nn.Linear(args.decoder_dim, args.decoder_dim)
        self.value = nn.Linear(args.decoder_dim, args.decoder_dim)

        # self.register_buffer("mask", torch.tril(torch.ones(config.ctx_len, config.ctx_len)))

        self.rotary_ndims = int(self.head_size * 0.5)
        self.rotary_emb = RotaryEmbedding(self.rotary_ndims)

        self.output = nn.Linear(args.decoder_dim, args.decoder_dim)

    def forward(self, x, memory, RoPE, key_padding_mask=None):
        B, T, C = x.size()
        _, L, M = memory.size()
        # print("x size: ", x.size(), 'memory size: ', memory.size())
        # print('B, T, C: ', B, T, C, 'L: ', L)

        # self-attention over the decoder input:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)
        k = self.key(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)    # (B, T, C) -> (B, nh, T, hs)
        v = self.value(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)  # (B, T, C) -> (B, nh, T, hs)

        q, query_pass = q[..., :self.rotary_ndims], q[..., self.rotary_ndims:]
        k, key_pass = k[..., :self.rotary_ndims], k[..., self.rotary_ndims:]
        # cos, sin = self.rotary_emb(q, seq_len=T)
        cos, sin = RoPE
        q, k = apply_rotary_pos_emb(q, k, cos, sin)  # rotary encoding
        q = torch.cat((q, query_pass), dim=-1)
        k = torch.cat((k, key_pass), dim=-1)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # self-attention: (B, nh, T, hs) * (B, nh, hs, T) -> (B, nh, T, T)
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask[:, None, None, :]        # (B, T) -> (B, 1, 1, T)
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v
        # print("after attention vals: ", x.shape)                       # (B, nh, T, T) * (B, nh, T, hs) -> (B, nh, T, hs)
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # x = self.output(x)
        # print("after linear: ", x.shape)

        # cross attention:
        q = self.query(x).view(B, T, self.num_heads, self.head_size).transpose(1, 2)       # (B, T, C) -> (B, nh, T, hs)
        k = self.key(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)    # (B, L, M) -> (B, nh, L, hs)
        v = self.value(memory).view(B, L, self.num_heads, self.head_size).transpose(1, 2)  # (B, L, M) -> (B, nh, L, hs)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))  # cross-attention: (B, nh, T, hs) * (B, nh, hs, L) -> (B, nh, T, L)
        # print("att size: ", att.size())
        if key_padding_mask is not None:
            # key_padding_mask was already reshaped to (B, 1, 1, T) above; reusing it here
            # assumes the same padding mask also covers the memory positions (i.e. its length matches L).
            att = att.masked_fill(key_padding_mask == 0, float('-inf'))
        att = F.softmax(att, dim=-1)                                     # softmax

        x = att @ v                                                      # (B, nh, T, L) * (B, nh, L, hs) -> (B, nh, T, hs)
        # print("x decoder size: ", x.size())
        x = x.transpose(1, 2).contiguous().view(B, T, -1)                # (B, nh, T, hs) -> (B, T, nh, hs) -> (B, T, C)
        # print("x decoder size transposed: ", x.size())
        x = self.output(x)

        if self.collect_attention_map:
            self.attention_map = att

        return x


class GeGLU(torch.nn.Module):
    def __init__(self, config, layer_id, time_shift=False):
        super().__init__()
        self.layer_id = layer_id

        if time_shift:
            self.time_shift = nn.ZeroPad2d((0, 0, 1, 0))

        hidden_sz = 3 * config.n_ffn
        self.key = nn.Linear(config.n_embd, hidden_sz)
        self.value = nn.Linear(config.n_embd, hidden_sz)
        self.weight = nn.Linear(hidden_sz, config.n_embd)

    def forward(self, x):
        B, T, C = x.size()
        if hasattr(self, 'time_shift'):
            x = torch.cat([self.time_shift(x)[:, :-1, :C//2], x[:, :, C//2:]], dim=-1)

        k = self.key(x)
        v = self.value(x)
        y = self.weight(F.gelu(k) * v)
        return y
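

########################################################################################################
# Minimal usage sketch (not part of the original module): wires the module's own RotaryEmbedding
# into MHA_rotary with a dummy config and runs one forward pass. The SimpleNamespace config and
# the dimensions below are assumptions for illustration only; the real training code may build
# `args` differently.
########################################################################################################

if __name__ == '__main__':
    from types import SimpleNamespace

    # hypothetical config: only the attributes read by MHA_rotary are set here
    args = SimpleNamespace(encoder_dim=64, num_heads=4, timeshift=False)
    mha = MHA_rotary(args)

    B, T = 2, 16
    x = torch.randn(B, T, args.encoder_dim)

    # stacked (cos, sin) table of shape (2, T, rotary_ndims); unpacked as `cos, sin = RoPE` in forward()
    RoPE = mha.rotary_emb(x, seq_len=T)

    out = mha(x, RoPE)
    print(out.shape)  # expected: torch.Size([2, 16, 64])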