# -*- coding: utf-8 -*- # Copyright (c) 2024, Songlin Yang, Yu Zhang from __future__ import annotations from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.nn as nn from einops import rearrange, repeat from transformers.activations import ACT2FN from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution from fla.modules.rotary import RotaryEmbedding from fla.ops.retention import (chunk_retention, fused_chunk_retention, fused_recurrent_retention, parallel_retention) if TYPE_CHECKING: from fla.models.utils import Cache class MultiScaleRetention(nn.Module): r""" The layer implementaion for [Retentive Network: A Successor to Transformer for Large Language Models](https://arxiv.org/pdf/2307.08621.pdf). # noqa Args: mode (str, Optional): Which Retention kernel to use. Currently available: `chunk`, `fused_recurrent`, `parallel`, and `fused_chunk`. Default: `fused_chunk`. hidden_size (int, Optional): The hidden size of the input. Default: 1024. expand_k (float, Optional): The expansion ratio for the key dim. Default: 1.0. expand_v (float, Optional): The expansion ratio for the value dim. Default: 2.0. num_heads (int, Optional): The number of heads. Default: 8. num_kv_heads (int, Optional): The number of key/value heads, used for MQA. Default: None. feature_map (str, Optional): Feature map function applied to queries/keys. Default: None. use_short_conv (bool, Optional): Whether to use short convolutions. Default: `False`. conv_size (int, Optional): The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. conv_bias (bool, Optional): Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. use_output_gate (bool, Optional): Whether to use output gate. Default: `True`. gate_fn (str, Optional): The activation function for the output gate. Default: `swish`. elementwise_affine (bool, Optional): If `True`, applies elementwise affine to LayerNorm with learnable parameters. Default: `True`. norm_eps (float, Optional): The epsilon value for the layernorm/rmsnorm layer. Default: 1e-5. fuse_norm (bool, Optional): Whether to fuse the norm and the output gate for better memory footprint. Default: `True`. layer_idx (int, Optional): The index of the layer. Default: None. """ def __init__( self, mode: str = 'chunk', hidden_size: int = 1024, expand_k: float = 1.0, expand_v: float = 2.0, num_heads: int = 8, num_kv_heads: Optional[int] = None, feature_map: Optional[str] = None, use_short_conv: bool = False, conv_size: int = 4, conv_bias: bool = False, use_output_gate: bool = True, gate_fn: str = 'swish', elementwise_affine: Optional[bool] = True, norm_eps: float = 1e-5, fuse_norm: bool = True, layer_idx: int = None, **kwargs ) -> MultiScaleRetention: super().__init__() self.mode = mode self.hidden_size = hidden_size self.expand_k = expand_k self.expand_v = expand_v self.num_heads = num_heads self.num_kv_heads = num_kv_heads if num_kv_heads is not None else num_heads self.num_kv_groups = self.num_heads // self.num_kv_heads self.feature_map_fn = ACT2FN[feature_map] if feature_map is not None else None self.use_short_conv = use_short_conv self.conv_size = conv_size self.conv_bias = conv_bias self.use_output_gate = use_output_gate self.key_dim = int(hidden_size * expand_k) self.value_dim = int(hidden_size * expand_v) self.key_dim_per_group = self.key_dim // self.num_kv_groups self.value_dim_per_group = self.value_dim // self.num_kv_groups self.layer_idx = layer_idx assert mode in ['chunk', 'fused_chunk', 'parallel', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." assert self.key_dim % num_heads == 0, f"key dim must be divisible by num_heads of {num_heads}" assert self.value_dim % num_heads == 0, f"value dim must be divisible by num_heads of {num_heads}" self.head_qk_dim = self.key_dim // num_heads self.head_v_dim = self.value_dim // num_heads self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) self.k_proj = nn.Linear(hidden_size, self.key_dim_per_group, bias=False) self.v_proj = nn.Linear(hidden_size, self.value_dim_per_group, bias=False) if self.use_output_gate: self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) if use_short_conv: self.conv_size = conv_size self.q_conv1d = ShortConvolution(self.key_dim, conv_size, activation='silu') self.k_conv1d = ShortConvolution(self.key_dim_per_group, conv_size, activation='silu') self.v_conv1d = ShortConvolution(self.value_dim_per_group, conv_size, activation='silu') self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) if gate_fn == 'swish' and fuse_norm and use_output_gate: self.g_norm_swish_gate = FusedRMSNormSwishGate(self.head_v_dim, elementwise_affine, norm_eps) self.fuse_norm_and_gate = True else: self.fuse_norm_and_gate = False self.g_norm = RMSNorm(hidden_size=self.head_v_dim, elementwise_affine=elementwise_affine, eps=norm_eps) self.gate_fn = ACT2FN[gate_fn] # TODO: fix this issue # https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/triton/rotary.py#L180 # Ideally, we would want to support arbitrary d_head_qk assert self.head_qk_dim <= 256, "head_qk_dim must be less than or equal to 256" self.rotary = RotaryEmbedding(dim=self.head_qk_dim) self.apply(self._initialize_weights) def _initialize_weights(self, module: nn.Module): if getattr(module, "_is_hf_initialized", False): return if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) if module.bias is not None: nn.init.zeros_(module.bias) module._is_hf_initialized = True def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache]]: if attention_mask is not None: assert len(attention_mask.shape) == 2, ( "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " "for padding purposes (0 indicating padding). " "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." ) # launching the triton kernel for just one token will actually be slower mode = 'fused_recurrent' if hidden_states.shape[1] == 1 else self.mode last_state = None if past_key_values is not None and len(past_key_values) > self.layer_idx: last_state = past_key_values[self.layer_idx] if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = None, None, None if last_state is not None: conv_state_q, conv_state_k, conv_state_v = last_state['conv_state'] conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None q, conv_state_q = self.q_conv1d(x=self.q_proj(hidden_states), mask=conv_mask, cache=conv_state_q, output_final_state=use_cache) k, conv_state_k = self.k_conv1d(x=self.k_proj(hidden_states), mask=conv_mask, cache=conv_state_k, output_final_state=use_cache) v, conv_state_v = self.v_conv1d(x=self.v_proj(hidden_states), mask=conv_mask, cache=conv_state_v, output_final_state=use_cache) else: q = self.q_proj(hidden_states) k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) # dealing with left-padding if attention_mask is not None: v = v.mul_(attention_mask[:, -v.shape[-2]:, None]) q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads) k = rearrange(k, '... (h d) -> ... h d', h=self.num_kv_heads) if self.feature_map_fn is not None: q, k = map(self.feature_map_fn, (q, k)) seqlen_offset, max_seqlen = 0, q.shape[1] if past_key_values is not None: seqlen_offset = past_key_values.get_seq_length(self.layer_idx) max_seqlen = q.shape[1] + seqlen_offset if attention_mask is not None: # to deliminate the offsets of padding tokens seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0) max_seqlen = q.shape[1] + max(seqlen_offset) q, k = self.rotary(q, k, seqlen_offset, max_seqlen) if self.num_kv_groups > 1: k = repeat(k, 'b t h d -> b t (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) v = repeat(v, 'b t (h d) -> b t (h g) d', h=self.num_kv_heads, g=self.num_kv_groups) else: k, v = rearrange(k, 'b t h d -> b t h d'), rearrange(v, 'b t (h d) -> b t h d', h=self.num_kv_heads) recurrent_state = last_state['recurrent_state'] if last_state is not None else None if mode == 'chunk': o, recurrent_state = chunk_retention( q=q, k=k, v=v, initial_state=recurrent_state, output_final_state=use_cache, head_first=False ) elif mode == 'fused_chunk': o, recurrent_state = fused_chunk_retention( q=q, k=k, v=v, initial_state=recurrent_state, output_final_state=use_cache, head_first=False ) elif mode == 'parallel': o, recurrent_state = parallel_retention(q, k, v, head_first=False) elif mode == 'fused_recurrent': o, recurrent_state = fused_recurrent_retention( q=q, k=k, v=v, initial_state=recurrent_state, output_final_state=use_cache, head_first=False ) else: raise NotImplementedError(f"Not supported mode `{mode}`.") if past_key_values is not None: past_key_values.update( recurrent_state=recurrent_state, conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, layer_idx=self.layer_idx, offset=q.shape[2] ) if self.use_output_gate: g = self.g_proj(hidden_states) if self.fuse_norm_and_gate: g = rearrange(g, 'b t (h d) -> b t h d', h=self.num_heads) o = self.g_norm_swish_gate(o, g) o = rearrange(o, 'b t h d -> b t (h d)') else: o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') o = o * self.gate_fn(g) else: o = rearrange(self.g_norm(o), 'b t h d -> b t (h d)') o = self.o_proj(o) return o, None, past_key_values def state_size(self, **kwargs) -> int: state_size = self.key_dim * self.head_v_dim for module in self.children(): if isinstance(module, ShortConvolution): state_size += module.state_size return state_size