|
|
|
|
|
|
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
|
|
from fla.ops.common.fused_recurrent import fused_recurrent |
|
|
|
|
|
def fused_recurrent_simple_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    scale: Optional[float] = None,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False,
    reverse: bool = False,
    offsets: Optional[torch.LongTensor] = None,
    head_first: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    r"""
    Fused recurrent Simple GLA: validates the arguments and dispatches to the shared `fused_recurrent` op.

    Args:
        q (torch.Tensor):
            queries of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        k (torch.Tensor):
            keys of shape `[B, H, T, K]` if `head_first=True` else `[B, T, H, K]`.
        v (torch.Tensor):
            values of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        g (torch.Tensor):
            Forget gates of shape `[B, H, T]` if `head_first=True` else `[B, T, H]`.
            Compared to GLA, the gating is head-wise instead of elementwise.
        scale (Optional[float]):
            Scale factor for the attention scores.
            If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
        initial_state (Optional[torch.Tensor]):
            Initial state of shape `[N, H, K, V]` for `N` input sequences.
            For equal-length input sequences, `N` equals the batch size `B`.
            Default: `None`.
        output_final_state (bool):
            Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
        reverse (bool):
            If `True`, process the state passing in reverse order. Default: `False`.
        offsets (Optional[torch.LongTensor]):
            Offsets of shape `[N+1]` defining the bos/eos positions of `N` variable-length sequences in the batch.
            For example,
            if `offsets` is `[0, 1, 3, 6, 10, 15]`, there are `N=5` sequences with lengths 1, 2, 3, 4 and 5 respectively.
            If provided, the inputs are concatenated and the batch size `B` is expected to be 1.
            Default: `None`.
        head_first (bool):
            Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
            Default: `True`.

    Returns:
        o (torch.Tensor):
            Outputs of shape `[B, H, T, V]` if `head_first=True` else `[B, T, H, V]`.
        final_state (Optional[torch.Tensor]):
            Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.

    Raises:
        ValueError:
            If `offsets` is given and the batch size of `q` is not 1, or if the number of
            initial states does not equal the number of sequences described by `offsets`.
        RuntimeError:
            If `offsets` is given together with `head_first=True`.

    Examples::
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from einops import rearrange
        >>> from fla.ops.simple_gla import fused_recurrent_simple_gla
        # inputs with equal lengths
        >>> B, T, H, K, V = 4, 2048, 4, 512, 512
        >>> q = torch.randn(B, T, H, K, device='cuda')
        >>> k = torch.randn(B, T, H, K, device='cuda')
        >>> v = torch.randn(B, T, H, V, device='cuda')
        # the gate is head-wise, i.e. one scalar per (batch, step, head)
        >>> g = F.logsigmoid(torch.randn(B, T, H, device='cuda'))
        >>> h0 = torch.randn(B, H, K, V, device='cuda')
        >>> o, ht = fused_recurrent_simple_gla(q, k, v, g,
                                               initial_state=h0,
                                               output_final_state=True,
                                               head_first=False)
        # for variable-length inputs, the batch size `B` is expected to be 1 and `offsets` is required
        # (the ellipsis lets the same pattern flatten both the 4-D q/k/v and the 3-D gate)
        >>> q, k, v, g = map(lambda x: rearrange(x, 'b t h ... -> 1 (b t) h ...'), (q, k, v, g))
        # for a batch with 4 sequences, offsets with 5 start/end positions are expected
        >>> offsets = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
        >>> o_var, ht_var = fused_recurrent_simple_gla(q, k, v, g,
                                                       initial_state=h0,
                                                       output_final_state=True,
                                                       offsets=offsets,
                                                       head_first=False)
        >>> assert o.allclose(o_var.view(o.shape))
        >>> assert ht.allclose(ht_var)
    """
    # Variable-length mode: all sequences are packed into a single batch row and
    # delimited by `offsets`, so the inputs must satisfy a few extra constraints.
    if offsets is not None:
        if q.shape[0] != 1:
            raise ValueError(
                f"The batch size is expected to be 1 rather than {q.shape[0]} when using `offsets`. "
                "Please flatten variable-length inputs before processing."
            )
        if head_first:
            raise RuntimeError("Sequences with variable lengths are not supported for head-first mode")
        # `offsets` holds N+1 boundaries for N sequences; one initial state is needed per sequence.
        num_sequences = len(offsets) - 1
        if initial_state is not None and initial_state.shape[0] != num_sequences:
            raise ValueError(
                f"The number of initial states is expected to be equal to the number of input sequences, "
                f"i.e., {num_sequences} rather than {initial_state.shape[0]}."
            )

    # Default to 1 / sqrt(K), with K taken from the last dimension of the keys.
    if scale is None:
        scale = k.shape[-1] ** -0.5

    o, final_state = fused_recurrent(
        q=q,
        k=k,
        v=v,
        g=g,
        scale=scale,
        initial_state=initial_state,
        output_final_state=output_final_state,
        reverse=reverse,
        offsets=offsets,
        head_first=head_first
    )
    return o, final_state
|
|