import json
import torch
import torch.nn as nn
from torch.nn import functional as F

# Default device: the GPU when torch can see one, otherwise the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'


# One causal self-attention head using scaled dot-product attention
class Head(nn.Module):
    """One causal (masked) self-attention head using scaled dot-product attention.

    Args:
        n_embed: width of the input features (last dimension of ``x``).
        head_size: width of the key/query/value projections and of the output.
        context_size: side length of the precomputed lower-triangular mask.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, n_embed, head_size, context_size, dropout=0.1):
        super().__init__()

        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular ones: position t may attend only to positions <= t.
        # Registered as a buffer so it moves with the module (.to/.cuda) and is
        # saved in the state_dict.
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embed); return (B, T, head_size)."""
        _, T, _ = x.shape
        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Reuse the registered mask, which already lives on the module's
        # device. Build a fresh mask on x's device only when T exceeds it.
        if T <= self.tril.size(0):
            tril = self.tril[:T, :T]
        else:
            tril = torch.tril(torch.ones(T, T, device=x.device))
        # Scale by 1/sqrt(head_size) (the key dimension), as in standard
        # scaled dot-product attention, not by the input width.
        wei = q @ k.transpose(-2, -1) * (k.size(-1) ** -0.5)
        wei = wei.masked_fill(tril == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        out = wei @ v
        return out


class MultiHeadAttention(nn.Module):
    """Run several causal attention heads in parallel and mix their outputs.

    Args:
        n_embed: width of the input and of the projected output.
        num_heads: number of independent ``Head`` modules.
        context_size: maximum sequence length covered by each head's mask.
        head_size: output width of each head.
        dropout: dropout probability used inside every head (on the attention
            weights) and on the projected output.
    """

    def __init__(self, n_embed, num_heads, context_size, head_size, dropout):
        super().__init__()

        # Forward the configured dropout; otherwise each head would silently
        # fall back to Head's default rate.
        self.heads = nn.ModuleList([
            Head(n_embed, head_size, context_size, dropout=dropout)
            for _ in range(num_heads)
        ])
        # The concatenated heads are num_heads * head_size wide. That equals
        # n_embed only when n_embed is divisible by num_heads.
        self.projection = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Concatenate every head's output on the last dim, project, apply dropout."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.projection(out)
        return self.dropout(out)


# simple feed forward layer
class FeedForward(nn.Module):
    """Position-wise MLP: widen 4x, apply ReLU, project back, then dropout.

    Args:
        n_embeds: input and output feature width.
        dropout: dropout probability applied to the projected output.
    """

    def __init__(self, n_embeds, dropout):
        super().__init__()
        hidden = n_embeds * 4
        layers = [
            nn.Linear(n_embeds, hidden),   # expansion
            nn.ReLU(),
            nn.Linear(hidden, n_embeds),   # projection back to n_embeds
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        mlp = self.net
        return mlp(x)


# Transformer block
class Block(nn.Module):
    """Pre-norm transformer block: causal self-attention then a feed-forward MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds
    its result back onto it.

    Args:
        n_embeds: residual stream width.
        n_head: number of attention heads; each head is ``n_embeds // n_head`` wide.
        context_size: maximum sequence length for the causal mask.
        dropout: dropout probability passed to both sub-layers.
    """

    def __init__(self, n_embeds, n_head, context_size, dropout):
        super().__init__()
        per_head = n_embeds // n_head
        self.sa = MultiHeadAttention(
            n_embeds, n_head, context_size, per_head, dropout
        )
        self.ffwd = FeedForward(n_embeds, dropout)
        self.ln1 = nn.LayerNorm(n_embeds)
        self.ln2 = nn.LayerNorm(n_embeds)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        h = x + self.sa(self.ln1(x))
        return h + self.ffwd(self.ln2(h))


# Decoder-only (GPT-style) transformer language model
class DecoderTransformer(nn.Module):
    """Decoder-only transformer language model.

    Token embeddings are summed with learned absolute position embeddings.
    The result passes through ``n_layer`` ``Block``s and a final LayerNorm,
    and is then projected to vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        n_embed: embedding / residual stream width.
        context_size: maximum sequence length (size of the position table).
        n_layer: number of transformer blocks.
        n_head: attention heads per block.
        dropout: dropout probability used inside the blocks.
    """

    def __init__(self, vocab_size, n_embed, context_size, n_layer, n_head, dropout):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(context_size, n_embed)
        self.blocks = nn.Sequential(
            *[Block(
                n_embeds=n_embed,
                n_head=n_head,
                context_size=context_size,
                dropout=dropout
            ) for _ in range(n_layer)]
        )
        self.ln_f = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: integer token ids of shape (B, T) with T <= context_size.
            targets: optional integer token ids of shape (B, T).

        Returns:
            ``(logits, None)`` with logits of shape (B, T, vocab_size) when
            ``targets`` is None. Otherwise ``(logits, loss)``, where logits
            is flattened to (B*T, vocab_size) and ``loss`` is the mean
            cross-entropy.

        Raises:
            IndexError: if T exceeds the size of the position table.
        """
        B, T = idx.shape
        max_len = self.position_embedding_table.num_embeddings
        if T > max_len:
            # Fail clearly here instead of inside the embedding lookup
            # (a device-side assert on CUDA).
            raise IndexError(f"sequence length {T} exceeds context_size {max_len}")
        token_embeds = self.token_embedding_table(idx)  # (B, T, n_embed)
        # Create positions on idx's device, not the module-level default,
        # so the model works wherever it (and its input) actually lives.
        positions = torch.arange(T, device=idx.device)
        pos_embeds = self.position_embedding_table(positions)  # (T, n_embed)
        x = token_embeds + pos_embeds
        x = self.ln_f(self.blocks(x))
        logits = self.lm_head(x)

        if targets is None:
            return logits, None

        # Flatten batch and time so cross_entropy sees (N, C) logits / (N,) targets.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        # reshape (not view) also accepts non-contiguous targets.
        targets = targets.reshape(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=50, context_size=None, temperature=1.0):
        """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

        Runs without autograd tracking: the sampled ids are not differentiable,
        so building a graph would only waste memory. This method does not
        switch the module to eval mode; dropout follows the caller's mode.

        Args:
            idx: (B, T) integer token ids used as the prompt.
            max_new_tokens: number of tokens to append.
            context_size: number of trailing tokens to condition on. Defaults
                to the size of the position-embedding table.
            temperature: the last-step logits are divided by this before the
                softmax. Values below 1 sharpen the distribution and values
                above 1 flatten it.

        Returns:
            Tensor of shape (B, T + max_new_tokens): the prompt followed by
            the sampled tokens.
        """
        if context_size is None:
            context_size = self.position_embedding_table.num_embeddings

        for _ in range(max_new_tokens):
            # Crop to the last context_size tokens the position table can index.
            idx_cond = idx[:, -context_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx


class Tokenizer:
    """Tokenizer mapping each vocabulary entry to its position in ``vocab``.

    Args:
        vocab: ordered sequence of symbols (typically single characters).
            The index of a symbol is its token id.
    """

    def __init__(self, vocab):
        self.vocab = vocab
        self.stoi = {ch: idx for idx, ch in enumerate(vocab)}
        self.itos = {idx: ch for idx, ch in enumerate(vocab)}

    def encode(self, s):
        """Return the list of token ids for the symbols of ``s``.

        Raises:
            KeyError: if ``s`` contains a symbol that is not in the vocabulary.
        """
        return [self.stoi[c] for c in s]

    def decode(self, i):
        """Join the symbols for the token ids ``i`` into a string.

        ``i`` may be any iterable of ints or a 1-D integer ``torch.Tensor``.

        Raises:
            KeyError: if an id is not in the vocabulary.
        """
        if isinstance(i, torch.Tensor):
            # Tensor elements hash by identity, so they never match the int
            # keys of itos; convert them to Python ints first.
            i = i.tolist()
        return ''.join(self.itos[x] for x in i)

    @classmethod
    def from_pretrained(cls, path):
        """Load a tokenizer from a JSON file holding the vocabulary."""
        # Explicit UTF-8 so non-ASCII vocabularies load the same on every platform.
        with open(path, 'r', encoding='utf-8') as f:
            vocab = json.load(f)
        return cls(vocab)

    def save_pretrained(self, path):
        """Write the vocabulary to ``path`` as JSON."""
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.vocab, f)