from typing import Optional

import torch
from einops import rearrange

from fla.ops.linear_attn.utils import normalize_output


def naive_chunk_linear_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    scale: Optional[float] = None,
    normalize: bool = False
) -> torch.Tensor:
    """
    Naive chunk-parallel reference for causal linear attention.

    The sequence is split into chunks of 64 tokens; each query combines an
    inter-chunk term (the key-value state accumulated over all previous
    chunks) with an intra-chunk term (masked attention inside its own chunk).
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    chunk_size = 64
    # split the sequence into chunks (the length must be divisible by 64)
    q = rearrange(q, 'b h (n c) d -> b h n c d', c=chunk_size) * scale
    k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
    v = rearrange(v, 'b h (n c) d -> b h n c d', c=chunk_size)
    # per-chunk key-value outer products, turned into an exclusive prefix sum:
    # kv[:, :, i] is the state accumulated from all chunks before chunk i
    kv = k.transpose(-1, -2) @ v
    kv = kv.cumsum(2)
    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
    # inter-chunk term: queries attend to the accumulated past state
    inter = q @ kv
    # intra-chunk term: attention within each chunk, with strictly
    # upper-triangular (i.e. future) positions masked out
    intra = ((
        q @ k.transpose(-1, -2)).masked_fill_(
        torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1),
        0
    )) @ v
    o = inter + intra
    o = rearrange(o, 'b h n c d -> b h (n c) d')
    if normalize:
        # q already carries `scale` from above; flatten q/k back to the full
        # sequence so the key prefix sum inside normalize_output runs over all
        # prior positions rather than restarting at each chunk boundary
        o = normalize_output(rearrange(q, 'b h n c d -> b h (n c) d'),
                             rearrange(k, 'b h n c d -> b h (n c) d'), o)
    return o
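

# A minimal sanity-check sketch (not part of the library): it compares the
# chunked implementation above against a direct O(T^2) masked-attention
# reference. The shapes, tolerances, and CPU tensors here are illustrative
# assumptions, not fla conventions.
if __name__ == '__main__':
    B, H, T, D = 2, 4, 128, 32  # T must be a multiple of the 64-token chunk
    q, k, v = (torch.randn(B, H, T, D) for _ in range(3))
    # reference: scaled causal linear attention without softmax,
    # ((scale * q @ k^T) masked to the lower triangle) @ v
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool))
    ref = ((q @ k.transpose(-1, -2)) * D ** -0.5).masked_fill(~mask, 0) @ v
    out = naive_chunk_linear_attn(q, k, v)
    assert torch.allclose(out, ref, atol=1e-4, rtol=1e-4)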