gla-16M-test / fla /ops /rwkv6 /chunk_naive.py
zaydzuhri's picture
Training in progress, step 2048
2f9282b verified
# -*- coding: utf-8 -*-
import torch
from einops import rearrange
def naive_chunk_rwkv6(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    chunk_size: int = 32
):
    """Chunk-wise reference implementation of the RWKV6 linear-attention recurrence.

    The output matches the sequential recurrence (per batch and head)::

        o_t = q_t @ (S_{t-1} + diag(u) k_t^T v_t)
        S_t = diag(exp(w_t)) S_{t-1} + k_t^T v_t,      S_0 = 0

    but is evaluated chunk by chunk: the state entering each chunk is carried
    across chunks, and interactions inside a chunk are computed directly.
    All arithmetic is done in float32 and the result is cast back to the
    dtype of ``q``.

    The function is autograd-friendly: no tensor that participates in the
    graph is modified in place, and decay factors are only exponentiated for
    causal (strictly earlier) positions, so masked-out pairs can never
    overflow to ``inf`` and poison the gradients with NaNs.

    Args:
        q: Queries, 4-D ``(batch, head, seq, key_dim)``.
        k: Keys, same shape as ``q``.
        v: Values, 4-D ``(batch, head, seq, value_dim)``.
        w: Per-step decay in log space (the code applies ``exp`` to sums of
            ``w``), same shape as ``q``. Presumably non-positive in practice;
            positive values are not rejected but may overflow.
        u: Bonus weight applied to the current token's own key/value pair.
            Must broadcast against ``(batch, head, key_dim)`` after a leading
            axis is added -- expected shape ``(head, key_dim)``.
        chunk_size: Number of time steps per chunk; must divide ``seq``.

    Returns:
        Tensor of shape ``(batch, head, seq, value_dim)`` with the dtype of ``q``.

    Raises:
        AssertionError: If the sequence length is not a multiple of ``chunk_size``.
        ValueError: If any of ``q``, ``k``, ``v``, ``w`` is not 4-D.
    """
    for name, x in (('q', q), ('k', k), ('v', v), ('w', w)):
        if x.dim() != 4:
            raise ValueError(
                f"{name} must be 4-D (batch, head, seq, dim), got shape {tuple(x.shape)}"
            )
    seq_len = q.shape[-2]
    assert seq_len % chunk_size == 0, (
        f"sequence length {seq_len} is not divisible by chunk_size {chunk_size}"
    )
    orig_dtype = q.dtype
    num_chunk = seq_len // chunk_size
    u = u.unsqueeze(0)

    # (b, h, n*c, d) -> (b, h, n, c, d), computed in float32.
    q, k, v, w = (x.unflatten(2, (num_chunk, chunk_size)).float() for x in (q, k, v, w))

    # Inclusive / exclusive running log-decay inside each chunk.
    w_cumsum = w.cumsum(-2)
    w_before = w_cumsum - w

    # Each key decayed to the end of its chunk, then the chunk's own
    # contribution to the recurrent state: (b, h, n, d, p).
    kw = k * (w_cumsum[..., -1, None, :] - w_cumsum).exp()
    wkv = kw.transpose(-1, -2) @ v

    # State entering every chunk. Built out of place (list + stack) because
    # writing into a tensor that autograd has saved for the decay gradient
    # would invalidate the backward pass.
    chunk_decay = w_cumsum[:, :, :, -1, :, None].exp()
    state = wkv.new_zeros(wkv.shape[:2] + wkv.shape[3:])
    states = [state]
    for i in range(num_chunk - 1):
        state = state * chunk_decay[:, :, i] + wkv[:, :, i]
        states.append(state)
    # The slice is a no-op unless there are zero chunks; it keeps shapes aligned then.
    wkv_new = torch.stack(states, dim=2)[:, :, :num_chunk]

    # Contribution of everything before the chunk, decayed up to (not including) step t.
    o_inter = torch.einsum('b h n d p, b h n c d -> b h n c p', wkv_new, q * w_before.exp())

    # Current-token bonus term, for all positions at once.
    bonus = (q * u.unsqueeze(2).unsqueeze(3) * k).sum(-1).unsqueeze(-1) * v

    # Strictly-causal interactions inside the chunk. Only positions j < i are
    # sliced in, so every exponent is a sum of past `w` values and no masking
    # (or device-side mask tensor) is needed.
    causal = []
    for i in range(chunk_size):
        log_decay = w_before[:, :, :, i, None] - w_cumsum[:, :, :, :i]
        attn = (q[:, :, :, i, None] * k[:, :, :, :i] * log_decay.exp()).sum(-1)
        causal.append((attn.unsqueeze(-1) * v[:, :, :, :i]).sum(-2))
    o_intra = torch.stack(causal, dim=3) + bonus

    o = o_inter + o_intra
    # (b, h, n, c, p) -> (b, h, n*c, p)
    return o.flatten(2, 3).to(orig_dtype)