from typing import Optional, Tuple

import torch

from fla.ops.simple_gla.fused_recurrent import fused_recurrent_simple_gla


def fused_recurrent_lightning_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    layer_idx: int,
    num_layers: int,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    cu_seqlens: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
r""" |
|
Args: |
|
q (torch.Tensor): |
|
queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
k (torch.Tensor): |
|
keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`. |
|
v (torch.Tensor): |
|
values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
layer_idx (int): |
|
The index of the current layer. |
|
num_layers (int): |
|
The total number of layers. Both `layer_idx` and `num_layers` are used to compute the decay factor. |
|
scale (Optional[int]): |
|
Scale factor for the attention scores. |
|
If not provided, it will default to `1 / sqrt(K)`. Default: `None`. |
|
initial_state (Optional[torch.Tensor]): |
|
Initial state of shape `[N, H, K, V]` for `N` input sequences. |
|
For equal-length input sequences, `N` equals the batch size `B`. |
|
Default: `None`. |
|
output_final_state (Optional[bool]): |
|
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. |
|
cu_seqlens (torch.LongTensor): |
|
Cumulative sequence lengths of shape `[N+1]` used for variable-length training, |
|
consistent with the FlashAttention API. |
|
head_first (Optional[bool]): |
|
Whether the inputs are in the head-first format, which is not supported for variable-length inputs. |
|
Default: `True`. |
|
|
|
Returns: |
|
o (torch.Tensor): |
|
Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`. |
|
final_state (torch.Tensor): |
|
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. |
|
""" |
|
    H = q.shape[1] if head_first else q.shape[2]
    # Per-head log decay rates: later heads get increasingly strong decay, while the
    # (1 - layer_idx / num_layers) term weakens the decay in deeper layers.
    s = -(8 / H * (1 - layer_idx / num_layers)) * q.new_tensor(range(H), dtype=torch.float)
    # Broadcast the per-head decay over the batch and time dimensions so it can be
    # consumed as a constant gate `g` of shape [B, H, T] (or [B, T, H]).
    if head_first:
        g = s[None, :, None].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    else:
        g = s[None, None, :].expand(q.shape[0], q.shape[1], q.shape[2]).contiguous()
    # Lightning attention reduces to simple GLA with a data-independent per-head decay,
    # so the recurrence itself is delegated to the simple GLA kernel.
    return fused_recurrent_simple_gla(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        cu_seqlens=cu_seqlens,
        head_first=head_first
    )
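

if __name__ == "__main__":
    # Minimal usage sketch, not part of the library API: the tensor shapes, dtypes and
    # the CUDA device below are illustrative assumptions. A GPU is assumed here because
    # the underlying fused recurrent kernel is implemented in Triton.
    B, H, T, K, V = 2, 4, 128, 64, 64
    q = torch.randn(B, H, T, K, device='cuda')
    k = torch.randn(B, H, T, K, device='cuda')
    v = torch.randn(B, H, T, V, device='cuda')
    # `layer_idx` / `num_layers` set the per-head decay; `output_final_state=True`
    # additionally returns the recurrent state of shape [B, H, K, V].
    o, final_state = fused_recurrent_lightning_attn(
        q, k, v,
        layer_idx=0,
        num_layers=12,
        output_final_state=True,
        head_first=True,
    )
    print(o.shape, final_state.shape)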