import torch
from torch import Tensor, nn
from torch.utils.checkpoint import checkpoint_sequential

from xformers.components.attention.utils import maybe_merge_masks
from xformers.components import MultiHeadDispatch
from xformers.components.attention import ScaledDotProduct

from transformers import AutoTokenizer


class RotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim_per_head: int,
        max_seq_len: int = 4096,
        interpolation_ratio: float | None = 0.25,
        device=None,
        dtype=None,
    ):
        super().__init__()

        self.dim_per_head = dim_per_head
        self.max_seq_len = max_seq_len

        # Inverse frequencies 1 / 10000^(2i / dim_per_head), each repeated twice
        # so that consecutive channels form (cos, sin) pairs.
        freqs = 1.0 / (
            10000
            ** (
                torch.arange(0, dim_per_head, 2, device=device, dtype=dtype).float()
                / dim_per_head
            )
        )
        freqs = torch.repeat_interleave(freqs, 2)

        # Angle table: one row per position, one column per channel.
        r = (
            freqs
            * torch.arange(max_seq_len, device=device, dtype=dtype).float()[:, None]
        )
        if interpolation_ratio is not None:
            # Position interpolation: shrink the angles to stretch the usable context.
            r = r * interpolation_ratio

        r1 = r.cos()
        self.register_buffer("r1", r1)

        r2 = r.sin()
        self.register_buffer("r2", r2)

        aranged = torch.arange(dim_per_head, device=device, dtype=dtype)

        # mask1 swaps each (even, odd) channel pair and mask2 negates the entry
        # landing in the even slot, so x[..., mask1] * mask2 is the rotate-half trick.
        mask1 = torch.where(
            aranged % 2 == 1,
            aranged - 1,
            aranged + 1,
        ).float()
        self.register_buffer("mask1", mask1)

        mask2 = torch.where(aranged % 2 == 0, -1, 1).float()
        self.register_buffer("mask2", mask2)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor. shape: (bs, n_heads, seq_len, dim_per_head)

        Returns:
            Tensor: input tensor with rotary embeddings applied.
                shape: (bs, n_heads, seq_len, dim_per_head)
        """

        assert (
            x.ndim == 4
        ), "input must have 4 dimensions: (bs, n_heads, seq_len, dim_per_head)"
        assert x.shape[3] % 2 == 0, "dim_per_head must be divisible by 2"

        x = x.transpose(1, 2)

        return (
            x * self.r1[None, : x.shape[1], None, :]
            + x[
                :,
                :,
                :,
                self.mask1.int(),
            ]
            * self.mask2.int()
            * self.r2[None, : x.shape[1], None, :]
        ).transpose(1, 2)

    def extra_repr(self) -> str:
        return f"dim_per_head={self.dim_per_head}, max_seq_len={self.max_seq_len}"


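# Usage sketch (illustrative, not part of the original module): the rotary
# embedding is applied per attention head to queries/keys that are already
# split into heads. The names below are made up for the example.
#
#     rope = RotaryEmbedding(dim_per_head=64)
#     q = torch.randn(2, 8, 128, 64)   # (bs, n_heads, seq_len, dim_per_head)
#     q_rot = rope(q)                  # same shape, positions now encoded

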
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-9):
        super().__init__()

        self.dim = dim
        self.trainable = nn.Parameter(data=torch.ones((dim,)), requires_grad=True)
        self.eps = eps

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor. shape: (bs, seq_len, embed_dim)

        Returns:
            Tensor: RMS-normalized tensor. shape: (bs, seq_len, embed_dim)
        """

        assert x.ndim == 3, "input must have 3 dimensions: (bs, seq_len, embed_dim)"

        return (
            x
            / torch.sqrt(torch.mean(torch.square(x), dim=-1) + self.eps)[:, :, None]
            * self.trainable
        )

    def extra_repr(self) -> str:
        return f"dim={self.dim}, eps={self.eps}"


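# Sanity-check sketch (illustrative): this layer computes
# x / sqrt(mean(x^2) + eps) * weight, the same formula as torch.nn.RMSNorm,
# so the two should agree up to numerical precision.
#
#     ours, ref = RMSNorm(16, eps=1e-9), nn.RMSNorm(16, eps=1e-9)
#     x = torch.randn(2, 4, 16)
#     torch.testing.assert_close(ours(x), ref(x))

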
class SiLU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor

        Returns:
            Tensor: x * sigmoid(x)
        """
        return x * x.sigmoid()


class SwiGLU(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()
        # Hidden size 8 * dim / 3 follows the LLaMA convention of keeping roughly
        # 2/3 of a 4 * dim feed-forward width for the gated variant.
        self.linear_inp1 = nn.Linear(dim, (8 * dim) // 3, bias=False)
        self.linear_inp2 = nn.Linear(dim, (8 * dim) // 3, bias=False)
        self.linear_out = nn.Linear((8 * dim) // 3, dim, bias=False)
        self.silu = SiLU()

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor. shape: (..., dim)

        Returns:
            Tensor: W_out(SiLU(W_1 x) * W_2 x). shape: (..., dim)
        """
        return self.linear_out(self.silu(self.linear_inp1(x)) * self.linear_inp2(x))


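# Usage sketch (illustrative): the block is shape-preserving on the last axis.
#
#     ffn = SwiGLU(dim=512)
#     x = torch.randn(2, 128, 512)
#     y = ffn(x)                       # -> (2, 128, 512)

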
class MistralTokenizer(nn.Module):
    def __init__(self, max_length=1024, *args, **kwargs):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-v0.1", *args, **kwargs
        )
        # The Mistral tokenizer ships without a pad token, so one is added here.
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        self.special_tokens_ids = {
            token: id
            for token, id in zip(
                self.tokenizer.special_tokens_map.keys(), self.tokenizer.all_special_ids
            )
        }
        self.max_length = max_length
        self.pad_token_id = self.tokenizer.pad_token_id

    def forward(self, text):
        return self.tokenizer(
            text,
            return_tensors="pt",
            return_attention_mask=False,
            max_length=self.max_length,
            truncation=True,
            padding=True,
            padding_side="right",
        )

    def convert_ids_to_tokens(self, ids):
        return self.tokenizer.convert_ids_to_tokens(ids)

    def decode(self, x):
        return self.tokenizer.batch_decode(x)

    def __len__(self):
        return len(self.tokenizer)


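# Usage sketch (illustrative): tokenization returns a dict with right-padded
# "input_ids". Note that this downloads the Mistral tokenizer from the
# Hugging Face Hub on first use.
#
#     tok = MistralTokenizer(max_length=32)
#     batch = tok(["hello world", "a slightly longer example sentence"])
#     batch["input_ids"].shape          # (2, padded_seq_len)
#     tok.decode(batch["input_ids"])    # list of decoded strings

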
class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        emb_size: int,
        n_heads: int,
        dropout: float = 0.0,
        use_rotary_embeddings: bool = False,
        bias_qkv: bool = False,
        bias_out: bool = False,
    ):
        super().__init__()
        self.emb_size = emb_size
        self.n_heads = n_heads
        assert (
            self.emb_size % n_heads == 0
        ), "Embedding size needs to be divisible by heads"

        self.head_dim = emb_size // n_heads

        self.use_rotary_embeddings = use_rotary_embeddings
        if self.use_rotary_embeddings:
            self.rotary_embed = RotaryEmbedding(self.head_dim)

        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=bias_qkv)
        self.dropout = nn.Dropout(dropout)
        self.out = nn.Linear(emb_size, emb_size, bias=bias_out)

        self.scaling = self.head_dim**-0.5

    def forward(self, x: Tensor, att_mask: Tensor | None = None):
        """
        Args:
            x (Tensor): input tensor. shape: (bs, seq_len, emb_size)
            att_mask (Tensor, optional): additive float mask (-inf at masked
                positions) or boolean mask (True = keep), broadcastable to
                (bs, n_heads, seq_len, seq_len).

        Returns:
            Tensor: attention output. shape: (bs, seq_len, emb_size)
        """
        qkv = self.qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda t: t.reshape(x.shape[0], -1, self.n_heads, self.head_dim).transpose(
                1, 2
            ),
            qkv,
        )

        if self.use_rotary_embeddings:
            q, k = self.rotary_embed(q), self.rotary_embed(k)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scaling

        if att_mask is not None:
            if att_mask.dtype == torch.bool:
                # Boolean masks follow the xformers convention (True = keep).
                dots = dots.masked_fill(~att_mask, float("-inf"))
            else:
                dots = dots + att_mask

        attn = self.dropout(torch.softmax(dots, dim=-1))
        out = (
            torch.matmul(attn, v).transpose(1, 2).reshape(x.shape[0], -1, self.emb_size)
        )
        out = self.out(out)

        return out


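# Usage sketch (illustrative): causal self-attention with an additive mask,
# i.e. -inf above the diagonal so each position only attends to the past.
#
#     mha = MultiHeadAttention(emb_size=512, n_heads=8, use_rotary_embeddings=True)
#     x = torch.randn(2, 128, 512)
#     causal = torch.full((128, 128), float("-inf")).triu(diagonal=1)
#     y = mha(x, att_mask=causal)      # -> (2, 128, 512)

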
class LLaMADecoderLayer(nn.Module):
    def __init__(
        self,
        emb_size: int,
        n_heads: int,
        dropout: float,
    ) -> None:
        super().__init__()
        self.emb_size = emb_size
        self.multihead_attn = MultiHeadDispatch(
            dim_model=emb_size,
            num_heads=n_heads,
            attention=ScaledDotProduct(
                dropout=dropout,
            ),
            bias=(False, False, False, False),
            use_rotary_embeddings=True,
        )
        self.rmsnorm1 = nn.RMSNorm(emb_size, eps=1e-9)
        self.rmsnorm2 = nn.RMSNorm(emb_size, eps=1e-9)
        self.swiglu = SwiGLU(emb_size)
        self.n_heads = n_heads

    def forward(self, in_tuple) -> tuple[Tensor, Tensor]:
        """
        Args:
            in_tuple (tuple[Tensor, Tensor]): tuple containing 2 tensors:
                x (Tensor): input tensor (bs, seq_len, dim)
                mask (Tensor): merged attention mask (bs, n_heads, seq_len, seq_len)

        Returns:
            tuple[Tensor, Tensor]: output tensor and the unchanged mask
        """
        assert len(in_tuple) == 2, "input tuple must have 2 elements"
        x, mask = in_tuple

        # Pre-norm residual blocks: attention first, then the SwiGLU feed-forward.
        x = self.multihead_attn(self.rmsnorm1(x), att_mask=mask) + x
        return self.swiglu(self.rmsnorm2(x)) + x, mask


class CustomAttentionLLaMaDecoder(LLaMADecoderLayer):
    """LLaMADecoderLayer variant using the hand-written MultiHeadAttention and RMSNorm."""

    def __init__(
        self,
        emb_size: int,
        n_heads: int,
        dropout: float,
    ) -> None:
        super().__init__(emb_size, n_heads, dropout)
        self.multihead_attn = MultiHeadAttention(
            emb_size=emb_size,
            n_heads=n_heads,
            bias_qkv=False,
            bias_out=False,
            use_rotary_embeddings=True,
            dropout=dropout,
        )
        self.rmsnorm1 = RMSNorm(emb_size, eps=1e-9)
        self.rmsnorm2 = RMSNorm(emb_size, eps=1e-9)


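# Usage sketch (illustrative): decoder layers take and return an (x, mask)
# tuple so they can be chained inside nn.Sequential / checkpoint_sequential.
#
#     layer = CustomAttentionLLaMaDecoder(emb_size=512, n_heads=8, dropout=0.0)
#     x = torch.randn(2, 128, 512)
#     causal = torch.full((128, 128), float("-inf")).triu(diagonal=1)
#     x, _ = layer((x, causal.expand(2, 8, 128, 128)))

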
class LLaMaBase(nn.Module):
    def __init__(
        self,
        embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
        n_chckpnt_segments: int = 1,
        tokenizer=MistralTokenizer(),
        **kwargs,
    ):
        """
        Args:
            embed_dim (int): embedding (model) dimension.
            n_layers (int): number of decoder layers.
            n_heads (int): number of attention heads.
            dropout (float): attention dropout probability.
            n_chckpnt_segments (int): number of segments for gradient checkpointing.
            tokenizer: tokenizer providing the vocabulary and the pad token id.
        """
        super().__init__()

        self.tokenizer = tokenizer
        self.vocab_len = len(tokenizer)
        self.n_heads = n_heads
        self.dropout = dropout
        self.n_layers = n_layers
        self.embed_dim = embed_dim
        self.n_segments = n_chckpnt_segments

        self.embed = nn.Embedding(
            self.vocab_len, embed_dim, padding_idx=self.tokenizer.pad_token_id
        )
        self.head = nn.Linear(embed_dim, self.vocab_len, bias=False)

    def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch):
        """
        Model forward method.

        Args:
            src (Tensor): input token ids. shape: (batch_size, seq_len)
            attn_mask (Tensor): attention mask. shape: (seq_len, seq_len)
            pad_mask (Tensor): padding mask. shape: (batch_size, seq_len)
        Returns:
            output (dict): output dict containing logits.
        """

        raise NotImplementedError

    def __str__(self):
        """
        Model prints with the number of parameters.
        """
        all_parameters = sum([p.numel() for p in self.parameters()])
        trainable_parameters = sum(
            [p.numel() for p in self.parameters() if p.requires_grad]
        )
        embedding_parameters = sum([p.numel() for p in self.embed.parameters()])

        result_info = super().__str__()
        result_info = result_info + f"\nAll parameters: {all_parameters}"
        result_info = result_info + f"\nTrainable parameters: {trainable_parameters}"
        result_info = (
            result_info
            + f"\nWithout embedding: {trainable_parameters - embedding_parameters}"
        )

        return result_info


class CustomAttentionLLaMa(LLaMaBase):
    def __init__(
        self,
        embed_dim: int = 512,
        n_layers: int = 2,
        n_heads: int = 8,
        dropout: float = 0.0,
        n_chckpnt_segments: int = 1,
        tokenizer=MistralTokenizer(),
        **kwargs,
    ):
        """
        Args:
            embed_dim (int): embedding (model) dimension.
            n_layers (int): number of decoder layers.
            n_heads (int): number of attention heads.
            dropout (float): attention dropout probability.
            n_chckpnt_segments (int): number of segments for gradient checkpointing.
            tokenizer: tokenizer providing the vocabulary and the pad token id.
        """
        super().__init__(
            embed_dim,
            n_layers,
            n_heads,
            dropout,
            n_chckpnt_segments,
            tokenizer,
        )

        self.decoders = nn.Sequential(
            *[
                CustomAttentionLLaMaDecoder(
                    emb_size=embed_dim, n_heads=self.n_heads, dropout=dropout
                )
                for _ in range(n_layers)
            ]
        )
        self.rmsnorm = RMSNorm(embed_dim, eps=1e-9)

    def forward(self, src: Tensor, attn_mask: Tensor, pad_mask: Tensor, **batch):
        """
        Model forward method.

        Args:
            src (Tensor): input token ids. shape: (batch_size, seq_len)
            attn_mask (Tensor): attention mask. shape: (seq_len, seq_len)
            pad_mask (Tensor): padding mask. shape: (batch_size, seq_len)
        Returns:
            output (dict): output dict containing logits of shape
                (batch_size, vocab_size, seq_len).
        """
        x = self.embed(src)
        sizes = x.shape
        # Merge the attention mask with the padding mask and reshape it to
        # (batch_size, n_heads, seq_len, seq_len) for the decoder layers.
        mask = maybe_merge_masks(
            attn_mask, pad_mask, sizes[0], sizes[1], self.n_heads
        ).view(x.shape[0], self.n_heads, sizes[1], sizes[1])
        # Gradient checkpointing over the stack of decoder layers.
        x, _ = checkpoint_sequential(self.decoders, self.n_segments, input=(x, mask))

        logits = self.head(self.rmsnorm(x))
        return {"logits": logits.permute(0, 2, 1)}
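

# End-to-end sketch (illustrative, not part of the original module). The exact
# mask dtypes/shapes expected by xformers' maybe_merge_masks depend on the data
# pipeline; shown here is one plausible setup with an additive causal mask and
# a boolean padding mask (True = real token).
#
#     model = CustomAttentionLLaMa(embed_dim=128, n_layers=2, n_heads=4)
#     batch = model.tokenizer(["Hello world", "a longer second sentence"])
#     src = batch["input_ids"]                      # (bs, seq_len)
#     seq_len = src.shape[1]
#     causal = torch.full((seq_len, seq_len), float("-inf")).triu(diagonal=1)
#     pad_mask = src != model.tokenizer.pad_token_id
#     out = model(src, attn_mask=causal, pad_mask=pad_mask)
#     out["logits"].shape                           # (bs, vocab_size, seq_len)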