import torch
from torch import Tensor, nn
from torch.nn import Sequential
from torch.utils.checkpoint import checkpoint, checkpoint_sequential
from xformers.components.attention.utils import maybe_merge_masks
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct
from transformers import AutoTokenizer
class RotaryEmbedding(nn.Module):
def __init__(
self,
dim_per_head: int,
max_seq_len: int = 4096,
interpolation_ratio: float | None = 0.25,
device=None,
dtype=None,
):
super().__init__()
self.dim_per_head = dim_per_head
self.max_seq_len = max_seq_len
        # RoPE inverse frequencies: theta_i = 10000 ** (-2 * i / dim_per_head).
        freqs = 1.0 / (
            10000
            ** (
                torch.arange(0, dim_per_head, 2, device=device, dtype=dtype).float()
                / dim_per_head
            )
        )
        # Each (even, odd) feature pair shares one rotation frequency.
        freqs = torch.repeat_interleave(freqs, 2)
r = (
freqs
* torch.arange(max_seq_len, device=device, dtype=dtype).float()[:, None]
)
        if interpolation_ratio is not None:
            # Position interpolation: shrink the angles to stretch the usable context.
            r = r * interpolation_ratio
        # Precomputed cos/sin tables of shape (max_seq_len, dim_per_head).
        r1 = r.cos()
        self.register_buffer("r1", r1)
        r2 = r.sin()
        self.register_buffer("r2", r2)
        aranged = torch.arange(dim_per_head, device=device, dtype=dtype)
        # mask1 holds the index of each feature's partner (2i <-> 2i + 1), and mask2
        # negates the partner placed at even positions, so that x[..., mask1] * mask2
        # implements the "rotate half" step of interleaved RoPE.
        mask1 = torch.where(
            aranged % 2 == 1,
            aranged - 1,
            aranged + 1,
        ).float()
        self.register_buffer("mask1", mask1)
        mask2 = torch.where(aranged % 2 == 0, -1, 1).float()
        self.register_buffer("mask2", mask2)
def forward(self, x: Tensor):
"""
        Args:
            x (Tensor): input tensor. shape: (bs, n_heads, seq_len, dim_per_head)
        Returns:
            Tensor: input tensor with rotary embeddings applied. shape: (bs, n_heads, seq_len, dim_per_head)
"""
assert (
x.ndim == 4
), "input must have 4 dimensions: (bs, n_heads, seq_len, dim_per_head)"
assert x.shape[3] % 2 == 0, "dim_per_head must be divisible by 2"
        # Move to (bs, seq_len, n_heads, dim_per_head) so the (seq_len, dim) cos/sin
        # tables broadcast over batch and heads, then transpose back at the end.
        x = x.transpose(1, 2)
return (
x * self.r1[None, : x.shape[1], None, :]
+ x[
:,
:,
:,
self.mask1.int(),
]
* self.mask2.int()
* self.r2[None, : x.shape[1], None, :]
).transpose(1, 2)
def extra_repr(self) -> str:
return f"dim_per_head={self.dim_per_head}, max_seq_len={self.max_seq_len}"
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-9):
super().__init__()
self.dim = dim
self.trainable = nn.Parameter(data=torch.ones((dim,)), requires_grad=True)
self.eps = eps
    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor. shape: (bs, seq_len, embed_dim)
        Returns:
            Tensor: RMS-normalized tensor. shape: (bs, seq_len, embed_dim)
        """
        assert x.ndim == 3, "input must have 3 dimensions: (bs, seq_len, embed_dim)"
        # x / RMS(x) * g, with the RMS taken over the embedding dimension.
        return (
            x
            / torch.sqrt(torch.mean(torch.square(x), dim=-1, keepdim=True) + self.eps)
            * self.trainable
        )
def extra_repr(self) -> str:
return f"dim={self.dim}, eps={self.eps}"
class SiLU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input
"""
return x * x.sigmoid()
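
# Example sketch (illustrative only): x * sigmoid(x) is exactly the SiLU/Swish
# activation, so this module behaves like torch.nn.functional.silu.
def _example_silu():
    x = torch.randn(4, 8)
    assert torch.allclose(SiLU()(x), torch.nn.functional.silu(x), atol=1e-6)
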
class SwiGLU(nn.Module):
def __init__(self, dim: int) -> None:
super().__init__()
self.linear_inp1 = nn.Linear(dim, (8 * dim) // 3, bias=False)
self.linear_inp2 = nn.Linear(dim, (8 * dim) // 3, bias=False)
self.linear_out = nn.Linear((8 * dim) // 3, dim, bias=False)
self.silu = SiLU()
# nn.init.xavier_uniform_(self.linear_inp1.weight)
# nn.init.xavier_uniform_(self.linear_inp2.weight)
# nn.init.xavier_uniform_(self.linear_out.weight)
def forward(self, x: Tensor):
"""
Args:
x (Tensor): input tensor
"""
return self.linear_out(self.silu(self.linear_inp1(x)) * self.linear_inp2(x))
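
# Example sketch (illustrative only): SwiGLU preserves the embedding dimension.
# The (8 * dim) // 3 hidden width keeps its three weight matrices at roughly the
# same parameter count as a classic two-matrix FFN with a 4x expansion.
def _example_swiglu():
    ffn = SwiGLU(dim=512)
    x = torch.randn(2, 16, 512)
    assert ffn(x).shape == x.shape
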
class MistralTokenizer(nn.Module):
def __init__(self, max_length=1024, *args, **kwargs):
super().__init__()
self.tokenizer = AutoTokenizer.from_pretrained(
"mistralai/Mistral-7B-v0.1", *args, **kwargs
)
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # Map special-token names (e.g. "pad_token") to ids; this relies on the two
        # lists below being in matching order.
        self.special_tokens_ids = {
            token: id
            for token, id in zip(
                self.tokenizer.special_tokens_map.keys(), self.tokenizer.all_special_ids
            )
        }
self.max_length = max_length
self.pad_token_id = self.tokenizer.pad_token_id
def forward(self, text):
return self.tokenizer(
text,
return_tensors="pt",
return_attention_mask=False,
max_length=self.max_length,
truncation=True,
padding=True,
padding_side="right",
)
def convert_ids_to_tokens(self, ids):
return self.tokenizer.convert_ids_to_tokens(ids)
def decode(self, x):
return self.tokenizer.batch_decode(x)
def __len__(self):
return len(self.tokenizer)
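
# Example sketch (illustrative only): tokenizing a small batch. Constructing
# MistralTokenizer loads the Mistral-7B tokenizer from the Hugging Face Hub, so
# this needs network access or a local cache.
def _example_tokenizer():
    tokenizer = MistralTokenizer(max_length=32)
    batch = tokenizer(["Hello world", "A slightly longer example sentence"])
    input_ids = batch["input_ids"]  # (batch_size, padded_seq_len), right-padded
    print(input_ids.shape)
    print(tokenizer.decode(input_ids))
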
class MultiHeadAttention(nn.Module):
def __init__(
self,
emb_size: int,
n_heads: int,
dropout: float = 0.0,
use_rotary_embeddings: bool = False,
bias_qkv: bool = False,
bias_out: bool = False,
):
super().__init__()
self.emb_size = emb_size
self.n_heads = n_heads
assert (
self.emb_size % n_heads == 0
), "Embedding size needs to be divisible by heads"
self.head_dim = emb_size // n_heads
self.use_rotary_embeddings = use_rotary_embeddings
if self.use_rotary_embeddings:
self.rotary_embed = RotaryEmbedding(self.head_dim)
self.qkv = nn.Linear(emb_size, emb_size * 3, bias=bias_qkv)
self.dropout = nn.Dropout(dropout)
self.out = nn.Linear(emb_size, emb_size, bias=bias_out)
self.scaling = self.head_dim**-0.5
    def forward(self, x: Tensor, att_mask: Tensor | None = None):
        """
        Args:
            x (Tensor): input tensor. shape: (bs, seq_len, emb_size)
            att_mask (Tensor | None): additive attention mask broadcastable to
                (bs, n_heads, seq_len, seq_len)
        Returns:
            Tensor: attention output. shape: (bs, seq_len, emb_size)
        """
qkv = self.qkv(x).chunk(3, dim=-1)
q, k, v = map(
lambda t: t.reshape(x.shape[0], -1, self.n_heads, self.head_dim).transpose(
1, 2
),
qkv,
) # [batch_size, n_heads, seq_len, head_dim]
if self.use_rotary_embeddings:
q, k = self.rotary_embed(q), self.rotary_embed(k)
dots = (
torch.matmul(q, k.transpose(-1, -2)) * self.scaling
) # [batch_size, n_heads, seq_len, seq_len]
        if att_mask is not None:
            # The mask is added to the attention scores, so it is expected to be
            # additive (0 for visible positions, -inf for masked ones).
            dots = dots + att_mask
attn = self.dropout(torch.softmax(dots, dim=-1))
out = (
torch.matmul(attn, v).transpose(1, 2).reshape(x.shape[0], -1, self.emb_size)
)
out = self.out(out)
return out
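
# Example sketch (illustrative only): running the custom attention with an
# additive causal mask (0 on visible positions, -inf above the diagonal), which
# is the kind of mask this implementation adds to the attention scores.
def _example_multi_head_attention():
    bs, seq_len, emb_size, n_heads = 2, 16, 512, 8
    attn = MultiHeadAttention(emb_size, n_heads, use_rotary_embeddings=True)
    x = torch.randn(bs, seq_len, emb_size)
    causal_mask = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
    out = attn(x, att_mask=causal_mask)
    assert out.shape == (bs, seq_len, emb_size)
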
class LLaMADecoderLayer(nn.Module):
def __init__(
self,
emb_size: int,
n_heads: int,
dropout: float,
) -> None:
super().__init__()
self.emb_size = emb_size
self.multihead_attn = MultiHeadDispatch(
dim_model=emb_size,
num_heads=n_heads,
attention=ScaledDotProduct(
dropout=dropout,
),
bias=(False, False, False, False),
use_rotary_embeddings=True,
)
self.rmsnorm1 = nn.RMSNorm(emb_size, eps=1e-9)
self.rmsnorm2 = nn.RMSNorm(emb_size, eps=1e-9)
self.swiglu = SwiGLU(emb_size)
self.n_heads = n_heads
    def forward(self, in_tuple) -> tuple[Tensor, Tensor]:
        """
        Args:
            in_tuple (tuple[Tensor, Tensor]): tuple containing 2 tensors:
                x (Tensor): input tensor (bs, seq_len, dim)
                mask (Tensor): merged attention + padding mask (bs, n_heads, seq_len, seq_len)
        Returns:
            tuple[Tensor, Tensor]: output tensor and the unchanged mask
        """
        assert len(in_tuple) == 2, "input tuple must have 2 elements"
        x, mask = in_tuple
        # Pre-norm residual blocks: self-attention followed by the SwiGLU feed-forward.
        x = self.multihead_attn(self.rmsnorm1(x), att_mask=mask) + x
        return self.swiglu(self.rmsnorm2(x)) + x, mask
class CustomAttentionLLaMaDecoder(LLaMADecoderLayer):
    """LLaMADecoderLayer variant using the custom MultiHeadAttention and RMSNorm."""
def __init__(
self,
emb_size: int,
n_heads: int,
dropout: float,
) -> None:
super().__init__(emb_size, n_heads, dropout)
self.multihead_attn = MultiHeadAttention(
emb_size=emb_size,
n_heads=n_heads,
bias_qkv=False,
bias_out=False,
use_rotary_embeddings=True,
dropout=dropout,
)
self.rmsnorm1 = RMSNorm(emb_size, eps=1e-9)
self.rmsnorm2 = RMSNorm(emb_size, eps=1e-9)
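
# Example sketch (illustrative only): a decoder block consumes and returns a
# (hidden states, mask) tuple, which is what lets the blocks be chained with
# nn.Sequential / checkpoint_sequential in CustomAttentionLLaMa below. An additive
# causal mask broadcast to (bs, n_heads, seq_len, seq_len) is assumed here; the
# full model builds its mask with maybe_merge_masks instead.
def _example_decoder_layer():
    bs, seq_len, emb_size, n_heads = 2, 16, 512, 8
    layer = CustomAttentionLLaMaDecoder(emb_size=emb_size, n_heads=n_heads, dropout=0.0)
    x = torch.randn(bs, seq_len, emb_size)
    causal = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
    mask = causal.expand(bs, n_heads, seq_len, seq_len)
    out, _ = layer((x, mask))
    assert out.shape == x.shape
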
class LLaMaBase(nn.Module):
def __init__(
self,
embed_dim: int = 512,
n_layers: int = 2,
n_heads: int = 8,
        dropout: float = 0.0,
n_chckpnt_segments: int = 1,
tokenizer=MistralTokenizer(),
**kwargs,
):
"""
Args:
n_feats (int): number of input features.
n_class (int): number of classes.
fc_hidden (int): number of hidden features.
"""
super().__init__()
self.tokenizer = tokenizer
self.vocab_len = len(tokenizer)
self.n_heads = n_heads
self.dropout = dropout
self.n_layers = n_layers
self.embed_dim = embed_dim
self.n_segments = n_chckpnt_segments
self.embed = nn.Embedding(
self.vocab_len, embed_dim, padding_idx=self.tokenizer.pad_token_id
)
self.head = nn.Linear(embed_dim, self.vocab_len, bias=False)
def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch):
"""
Model forward method.
Args:
tokenized (Tensor): input text. shape: (batch_size, seq_len)
Returns:
output (dict): output dict containing logits.
"""
raise NotImplementedError
def __str__(self):
"""
Model prints with the number of parameters.
"""
all_parameters = sum([p.numel() for p in self.parameters()])
trainable_parameters = sum(
[p.numel() for p in self.parameters() if p.requires_grad]
)
embedding_parameters = sum([p.numel() for p in self.embed.parameters()])
result_info = super().__str__()
result_info = result_info + f"\nAll parameters: {all_parameters}"
result_info = result_info + f"\nTrainable parameters: {trainable_parameters}"
result_info = (
result_info
+ f"\nWithout embedding: {trainable_parameters - embedding_parameters}"
)
return result_info
class CustomAttentionLLaMa(LLaMaBase):
def __init__(
self,
embed_dim: int = 512,
n_layers: int = 2,
n_heads: int = 8,
        dropout: float = 0.0,
n_chckpnt_segments: int = 1,
tokenizer=MistralTokenizer(),
**kwargs,
):
"""
Args:
n_feats (int): number of input features.
n_class (int): number of classes.
fc_hidden (int): number of hidden features.
"""
super().__init__(
embed_dim,
n_layers,
n_heads,
dropout,
n_chckpnt_segments,
tokenizer,
)
self.decoders = nn.Sequential(
*[
CustomAttentionLLaMaDecoder(
emb_size=embed_dim, n_heads=self.n_heads, dropout=dropout
)
for _ in range(n_layers)
]
)
self.rmsnorm = RMSNorm(embed_dim, eps=1e-9)
def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch):
"""
Model forward method.
Args:
tokenized (Tensor): input text. shape: (batch_size, seq_len)
Returns:
output (dict): output dict containing logits.
"""
x = self.embed(src) # embeds shape: [batch_size, seq_len, embed_dim]
sizes = x.shape
        # Merge the causal attention mask with the padding mask and reshape it to
        # (batch_size, n_heads, seq_len, seq_len) for the decoder stack.
        mask = maybe_merge_masks(
            attn_mask, pad_mask, sizes[0], sizes[1], self.n_heads
        ).view(x.shape[0], self.n_heads, sizes[1], sizes[1])
        # Run the decoders with activation checkpointing to save memory.
        x, _ = checkpoint_sequential(self.decoders, self.n_segments, input=(x, mask))
        # Equivalent loop without checkpointing:
        # for decoder in self.decoders:
        #     x, mask = decoder((x, mask))
logits = self.head(self.rmsnorm(x))
return {
"logits": logits.permute(0, 2, 1)
} # logits shape: [batch_size, vocab_len, seq_len]
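
# Example sketch (illustrative only): end-to-end smoke test of the forward pass.
# Boolean masks with True marking positions that may be attended to are assumed
# here (the convention expected by xformers' maybe_merge_masks); the training
# pipeline that feeds this model may build them differently. Instantiating the
# model also downloads the Mistral tokenizer.
def _example_model_forward():
    model = CustomAttentionLLaMa(embed_dim=256, n_layers=2, n_heads=8)
    batch = model.tokenizer(["Hello world", "A longer example sentence"])
    src = batch["input_ids"]  # (batch_size, seq_len)
    seq_len = src.shape[1]
    attn_mask = torch.ones(seq_len, seq_len, dtype=torch.bool).tril()  # causal mask
    pad_mask = src != model.tokenizer.pad_token_id  # True on non-pad tokens
    out = model(src, attn_mask=attn_mask, pad_mask=pad_mask)
    print(out["logits"].shape)  # (batch_size, vocab_len, seq_len)


if __name__ == "__main__":
    _example_model_forward()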