from typing import Optional, Tuple

import torch
import triton
import triton.language as tl

from fla.utils import contiguous


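# Chunk-parallel implementation of the HGRN elementwise gated recurrence
#     h[t] = exp(g[t]) * h[t-1] + x[t]
# The forward pass first runs an intra-chunk scan that also materializes the
# cumulative log-gates `gc`, then a second kernel folds each chunk's carry-in
# state into the chunk outputs.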
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_fwd_kernel_h(
    x,
    g,
    gc,
    o,
    h0,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr
):
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    p_x = x + i_b * T * D + i_t * BT * D + o_d
    p_g = g + i_b * T * D + i_t * BT * D + o_d
    p_gc = gc + i_b * T * D + i_t * BT * D + o_d
    p_o = o + i_b * T * D + i_t * BT * D + o_d

    b_h = tl.zeros([BD], dtype=tl.float32)
    b_gc = tl.zeros([BD], dtype=tl.float32)
    if USE_INITIAL_STATE:
        if i_t == 0:
            b_h += tl.load(h0 + i_b * D + o_d, mask=mask, other=0).to(tl.float32)
    for i in range(0, BT):
        mask_t = mask & ((i_t * BT + i) < T)
        b_x = tl.load(p_x, mask=mask_t, other=0).to(tl.float32)
        b_g = tl.load(p_g, mask=mask_t, other=0).to(tl.float32)
        b_h = tl.exp(b_g) * b_h + b_x
        b_gc = b_gc + b_g
        tl.store(p_gc, b_gc.to(p_o.dtype.element_ty), mask=mask_t)
        tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask_t)

        p_x += D
        p_g += D
        p_gc += D
        p_o += D


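# Inter-chunk pass: for every chunk after the first, fold the last hidden state
# of the previous chunk into this chunk's outputs via the cumulative gates.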
@triton.jit
def chunk_hgrn_fwd_kernel_o(
    gc,
    o,
    s_b,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(1, tl.cdiv(T, BT)):
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        b_h0 = tl.load(o + i_b * T * D + i_t * BT * D - D + o_d, mask=mask, other=0).to(tl.float32)

        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_o = b_o + tl.exp(b_gc) * b_h0[None, :]
        tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))


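# Backward pass. Within each chunk the first kernel runs the reverse scan
#     dh[t] = do[t] + exp(g[t+1]) * dh[t+1],    dx[t] = dh[t]
# and reuses the `gc` buffer for reverse-accumulated log-gates; the second
# kernel then folds in the carry from the following chunk and computes
#     dg[t] = o[t-1] * dx[t] * exp(g[t]).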
@triton.autotune(
    configs=[
        triton.Config({'BD': 32}, num_warps=1),
        triton.Config({'BD': 32}, num_warps=2),
        triton.Config({'BD': 32}, num_warps=4),
        triton.Config({'BD': 32}, num_warps=8),
        triton.Config({'BD': 64}, num_warps=1),
        triton.Config({'BD': 64}, num_warps=2),
        triton.Config({'BD': 64}, num_warps=4),
        triton.Config({'BD': 64}, num_warps=8),
        triton.Config({'BD': 128}, num_warps=1),
        triton.Config({'BD': 128}, num_warps=2),
        triton.Config({'BD': 128}, num_warps=4),
        triton.Config({'BD': 128}, num_warps=8),
    ],
    key=['D']
)
@triton.jit
def chunk_hgrn_bwd_kernel_h(
    g,
    gc,
    dx,
    do,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_t, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D
    BC = min(BT, T - i_t * BT)
    NT = tl.num_programs(1)

    p_g = g + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_gc = gc + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_dx = dx + (i_b * T + i_t * BT + BC - 1) * D + o_d
    p_do = do + (i_b * T + i_t * BT + BC - 1) * D + o_d

    if i_t == NT - 1:
        b_gc = tl.zeros([BD], dtype=tl.float32)
    else:
        b_gc = tl.load(g + (i_b * T + i_t * BT + BT) * D + o_d, mask=mask, other=0).to(tl.float32)
    b_dh = tl.zeros([BD], dtype=tl.float32)
    for _ in range(BC - 1, -1, -1):
        tl.store(p_gc, b_gc.to(p_gc.dtype.element_ty), mask=mask)

        b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)
        b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)

        b_gc = b_gc + b_g
        b_dh = b_dh + b_do
        b_dx = b_dh
        b_dh = b_dh * tl.exp(b_g)

        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)

        p_g -= D
        p_gc -= D
        p_dx -= D
        p_do -= D


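# Inter-chunk backward pass: propagate the carry dh from the following chunk
# into dx, then form the gate gradient dg from the time-shifted outputs.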
@triton.jit
def chunk_hgrn_bwd_kernel_o(
    g,
    gc,
    o,
    dx,
    dg,
    s_b,
    s_t,
    s_d,
    T: tl.constexpr,
    D: tl.constexpr,
    BT: tl.constexpr,
    BD: tl.constexpr
):
    i_d, i_b = tl.program_id(0), tl.program_id(1)
    o_d = i_d * BD + tl.arange(0, BD)
    mask = o_d < D

    for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):
        p_g = tl.make_block_ptr(g + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_gc = tl.make_block_ptr(gc + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_o = tl.make_block_ptr(o + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT - 1, i_d * BD), (BT, BD), (1, 0))
        p_dx = tl.make_block_ptr(dx + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))
        p_dg = tl.make_block_ptr(dg + i_b * s_b, (T, D), (s_t, s_d), (i_t * BT, i_d * BD), (BT, BD), (1, 0))

        mask_t = mask & ((i_t + 1) * BT < T)
        b_ht = tl.load(dx + i_b * T * D + (i_t + 1) * BT * D + o_d, mask=mask_t, other=0).to(tl.float32)

        b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
        b_gc = tl.load(p_gc, boundary_check=(0, 1)).to(tl.float32)
        b_o = tl.load(p_o, boundary_check=(0, 1)).to(tl.float32)
        b_dx = tl.load(p_dx, boundary_check=(0, 1)).to(tl.float32)

        b_dx = b_dx + tl.exp(b_gc) * b_ht[None, :]
        b_dg = b_o * b_dx * tl.exp(b_g)
        tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), boundary_check=(0, 1))
        tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1))


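# Autograd wrapper launching the chunked kernels with a fixed chunk size
# BT = 128; the feature-dimension block size is autotuned for the scan kernels
# and set heuristically for the inter-chunk kernels.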
class ChunkHGRNFunction(torch.autograd.Function):

    @staticmethod
    @contiguous
    def forward(ctx, x, g, initial_state=None, output_final_state=False):
        B, T, D = x.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        gc = torch.empty_like(g, dtype=torch.float)
        o = torch.empty_like(x, dtype=torch.float)

        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_fwd_kernel_h[grid](
            x, g, gc, o, initial_state,
            T=T, D=D, BT=BT,
            USE_INITIAL_STATE=initial_state is not None
        )

        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_fwd_kernel_o[grid](
            gc, o,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        final_state = None
        if output_final_state:
            final_state = o[:, -1].clone()
        o = o.to(x.dtype)
        ctx.save_for_backward(g, o, initial_state)
        return o, final_state

    @staticmethod
    @contiguous
    def backward(ctx, do, dht=None):
        g, o, initial_state = ctx.saved_tensors
        B, T, D = do.shape
        BT, BD = 128, min(64, triton.next_power_of_2(D))
        num_warps = 8 if BD == 64 else 4

        gc = torch.empty_like(g, dtype=torch.float)
        dx = torch.empty_like(o, dtype=torch.float)

        def grid(meta): return (triton.cdiv(D, meta['BD']), triton.cdiv(T, meta['BT']), B)
        chunk_hgrn_bwd_kernel_h[grid](
            g, gc, dx, do,
            T=T, D=D, BT=BT
        )

        dg = torch.empty_like(g, dtype=torch.float)

        def grid(meta): return (triton.cdiv(D, meta['BD']), B)
        chunk_hgrn_bwd_kernel_o[grid](
            g, gc, o, dx, dg,
            o.stride(-3), o.stride(-2), o.stride(-1),
            T=T, D=D, BT=BT, BD=BD,
            num_warps=num_warps
        )
        if initial_state is not None:
            dg[:, 0] = (initial_state * dx[:, 0] * g[:, 0].float().exp()).to(dg.dtype)

        return dx.to(o.dtype), dg, None, None


def chunk_hgrn(
    x: torch.Tensor,
    g: torch.Tensor,
    initial_state: Optional[torch.Tensor] = None,
    output_final_state: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
    return ChunkHGRNFunction.apply(x, g, initial_state, output_final_state)
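

# Minimal usage sketch: assumes a CUDA device with Triton available and checks
# `chunk_hgrn` against a naive per-step PyTorch recurrence
#     h[t] = exp(g[t]) * h[t-1] + x[t].
# Shapes and tolerances are illustrative only.
if __name__ == '__main__':
    B, T, D = 2, 300, 100
    x = torch.randn(B, T, D, device='cuda')
    # log-gates in (-inf, 0) so that exp(g) lies in (0, 1)
    g = torch.nn.functional.logsigmoid(torch.randn(B, T, D, device='cuda'))
    x, g = x.requires_grad_(True), g.requires_grad_(True)

    # naive reference recurrence, unrolled over time
    h = torch.zeros(B, D, device='cuda')
    ref = []
    for t in range(T):
        h = g[:, t].exp() * h + x[:, t]
        ref.append(h)
    ref = torch.stack(ref, dim=1)
    ref.sum().backward()
    ref_dx, ref_dg = x.grad.clone(), g.grad.clone()
    x.grad, g.grad = None, None

    # chunked Triton implementation
    out, _ = chunk_hgrn(x, g)
    out.sum().backward()

    assert torch.allclose(out, ref, atol=1e-4), (out - ref).abs().max()
    assert torch.allclose(x.grad, ref_dx, atol=1e-4), (x.grad - ref_dx).abs().max()
    assert torch.allclose(g.grad, ref_dg, atol=1e-4), (g.grad - ref_dg).abs().max()
    print('chunk_hgrn matches the naive recurrence')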