|
|
|
|
|
from typing import Optional |
|
|
|
import torch |
|
|
|
|
|
def ceildiv(a, b):
    """Return the ceiling of a / b.

    Equivalent to ``-(a // -b)``: the quotient rounded toward
    positive infinity, for both int and float operands.
    """
    quot, rem = divmod(a, b)
    # Floor division plus one whenever there is a nonzero remainder
    # gives the ceiling.
    return quot + 1 if rem else quot
|
|
|
|
|
def naive_recurrent_gla(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    gk: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False
):
    """Reference (step-by-step) implementation of gated linear attention.

    For each time step t the recurrent state is updated as

        h_t = h_{t-1} * exp(gk_t) + k_t v_t^T
        o_t = (q_t * K**-0.5)^T h_t

    All math is done in float32; the output is cast back to the input dtype.

    Args:
        q: queries — assumed shape (B, H, T, K); TODO confirm with callers.
        k: keys, same shape as ``q``.
        v: values — assumed shape (B, H, T, V).
        gk: per-key log decay gates, same shape as ``q``.
        initial_state: optional starting state of shape (B, H, K, V).
        output_final_state: when True, also return the last state ``h_T``.

    Returns:
        Tuple ``(o, h)`` where ``o`` has the shape of ``v`` (input dtype) and
        ``h`` is the float32 final state, or ``None`` unless requested.
    """
    in_dtype = q.dtype
    # Run the whole recurrence in float32 for numerical stability.
    q, k, v, gk = (t.float() for t in (q, k, v, gk))
    B, H, T, K = q.shape
    V = v.shape[-1]
    scale = K ** -0.5

    # Recurrent state: one (K, V) outer-product accumulator per (batch, head).
    state = q.new_zeros(B, H, K, V, dtype=torch.float32)
    if initial_state is not None:
        state = state + initial_state.float()

    outputs = []
    for t in range(T):
        decay = gk[:, :, t].exp()
        # Decay the old state per key dimension, then add the rank-1 update.
        state = state * decay.unsqueeze(-1) \
            + k[:, :, t].unsqueeze(-1) * v[:, :, t].unsqueeze(-2)
        # Read out: scaled query contracts the K dimension of the state.
        outputs.append(((q[:, :, t] * scale).unsqueeze(-1) * state).sum(-2))
    o = torch.stack(outputs, dim=2)

    final_state = state if output_final_state else None
    return o.to(in_dtype), final_state
|
|