|
|
|
|
|
import torch |
|
from einops import rearrange |
|
|
|
|
|
def naive_chunk_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    chunk_size: int = 32
) -> torch.Tensor:
    """Reference chunk-wise RWKV6 attention written with plain PyTorch ops.

    The sequence is split into chunks of ``chunk_size`` positions. Each output
    is the sum of an inter-chunk term (query against the decayed state carried
    over from earlier chunks) and an intra-chunk term (strictly earlier
    positions of the same chunk, plus the current position weighted by ``u``).
    All arithmetic is done in float32 and the result is cast back to the
    dtype of ``q``.

    Args:
        q: Queries, 4-D ``(b, h, t, d_k)``.
        k: Keys, 4-D ``(b, h, t, d_k)``.
        v: Values, 4-D ``(b, h, t, d_v)``.
        w: Per-position log-decay, same layout as ``k``; it is accumulated
            with ``cumsum`` and exponentiated, so entries are expected to be
            non-positive.
        u: Bonus weight for the current position. It is broadcast as
            ``u.unsqueeze(0).unsqueeze(2)``, i.e. an ``(h, d_k)`` tensor.
        chunk_size: Chunk length; must divide the sequence length.

    Returns:
        Tensor of shape ``(b, h, t, d_v)`` with the dtype of ``q``.

    Raises:
        AssertionError: If the sequence length is not a multiple of
            ``chunk_size``.
        ValueError: If ``q``, ``k``, ``v`` or ``w`` is not 4-dimensional.
    """
    seq_len = q.shape[-2]
    assert seq_len % chunk_size == 0, (
        f"sequence length {seq_len} must be a multiple of chunk_size {chunk_size}"
    )
    orig_dtype = q.dtype
    num_chunk = seq_len // chunk_size
    # (h, d) -> (1, h, 1, d) so it broadcasts against per-position (b, h, n, d).
    bonus = u.unsqueeze(0).unsqueeze(2)

    def _to_chunks(x: torch.Tensor) -> torch.Tensor:
        # (b, h, n * c, d) -> (b, h, n, c, d), computed in float32.
        b, h, t, d = x.shape
        return x.reshape(b, h, t // chunk_size, chunk_size, d).float()

    q, k, v, w = (_to_chunks(x) for x in (q, k, v, w))

    # Inclusive cumulative log-decay inside each chunk.
    w_cumsum = w.cumsum(-2)

    # Each chunk's own contribution to the state at the end of that chunk:
    # keys are decayed from their position up to the chunk's last position.
    kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp()
    wkv = kw.transpose(-1, -2) @ v

    # State visible at the start of every chunk. The recurrence is carried in
    # a separate tensor and stacked afterwards: writing slices of one buffer
    # in place while reading earlier slices invalidates tensors autograd has
    # saved and makes backward() fail when `w` requires grad.
    if num_chunk > 0:
        state = wkv.new_zeros(wkv.shape[:2] + wkv.shape[3:])
        states = [state]
        for i in range(num_chunk - 1):
            state = state * w_cumsum[:, :, i, -1, :, None].exp() + wkv[:, :, i]
            states.append(state)
        wkv_new = torch.stack(states, dim=2)
    else:
        wkv_new = torch.zeros_like(wkv)

    # Queries see the carried state decayed up to (not including) themselves.
    o_inter = torch.einsum(
        'b h n d p, b h n c d -> b h n c p', wkv_new, q * (w_cumsum - w).exp()
    )

    positions = torch.arange(chunk_size, device=q.device)
    rows = []
    for i in range(chunk_size):
        q_i, k_i, v_i = q[:, :, :, i], k[:, :, :, i], v[:, :, :, i]
        not_past = positions >= i  # only strictly earlier positions contribute

        # Log-decay between every position of the chunk and position i.
        # Non-causal entries have positive exponents that may overflow, so
        # they are set to -inf *before* exp; masking only afterwards leaves
        # inf in the graph and produces NaN gradients (0 * inf).
        log_decay = w_cumsum[:, :, :, i, None] - w[:, :, :, i, None] - w_cumsum
        log_decay = log_decay.masked_fill(not_past[:, None], float('-inf'))

        attn = (q_i.unsqueeze(-2) * k * log_decay.exp()).sum(-1)
        # Keep exact zeros on masked entries even for non-finite q/k values.
        attn = attn.masked_fill(not_past, 0)

        past_out = (attn.unsqueeze(-1) * v).sum(-2)
        # Current position: uses the bonus `u` instead of a decay factor.
        bonus_out = (q_i * bonus * k_i).sum(-1, keepdim=True) * v_i
        rows.append(past_out + bonus_out)
    o_intra = torch.stack(rows, dim=3)

    o = o_inter + o_intra
    # (b, h, n, c, d_v) -> (b, h, n * c, d_v)
    return o.flatten(2, 3).to(orig_dtype)
|
|