diff --git "a/modeling_blocks_rwkv7.py" "b/modeling_blocks_rwkv7.py" new file mode 100644--- /dev/null +++ "b/modeling_blocks_rwkv7.py" @@ -0,0 +1,2426 @@ +# ========== AUTO GENERATED FILE ========= +# This file is auto generated by 'hf_builder.py', do not edit this file directly +# As part of the RWKV/RWKV-block project +# ========== =================== ========= +# ---------------- +# block/kernel/rwkv7_attn_pytorch.py +# ---------------- +import torch + +# Enable tensorfloat 32 +torch.set_float32_matmul_precision('high') + +# Handles the RWKV v7 attention mechanic, in pure pytorch +def rwkv7_attn_pytorch( + r,w,k,v, kk,a, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + + ### Reference implement + # return rwkv7_attn_pytorch_ref( + # r,w,k,v, kk,a, + # BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + # xx, wkv_state_in + # ) + ### + + # # This works, but it has too much of a vram overhead + ### + # return rwkv7_attn_pytorch_v2_chunk_w_compile_break( + # r,w,k,v, kk,a, + # BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + # xx, wkv_state_in + # ) + ### + + # > per 9k chunk, per block, on a 4090 ... + # > with forward_with_reduce_compile on the timemix ... + # + # Somehow... + # The reference implement takes: 2281ms + # The chunked version takes: 389ms (chunksize 256) + + # Get the shape + B,T,HC = w.shape + + # Compute the chunks + chunk_size = 256 + chunk_count = SEQ_LEN // chunk_size + chunk_remainder = SEQ_LEN % chunk_size + + # The wkv_state_out + wkv_state_out = wkv_state_in.float() + + # # List of tensor to build + # xlist = [] + xx = xx.clone() + + # Loop over the chunks + for i in range(chunk_count): + sta = i * chunk_size + end = sta + chunk_size + + xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( + # xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break( + r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], + kk[:,sta:end],a[:,sta:end], + BATCH_SIZE, chunk_size, N_HEAD, HEAD_SIZE, + # xx[:,sta:end], wkv_state_out + torch.zeros(B,chunk_size,HC, dtype=xx.dtype, device=xx.device), wkv_state_out + ) + # xlist.append(xpart) + + # Handle the remainder + if chunk_remainder > 0: + sta = chunk_count * chunk_size + end = sta + chunk_remainder + + xx[:,sta:end], wkv_state_out = rwkv7_attn_pytorch_v2_chunk_w_compile_break( + # xpart, wkv_state_out = rwkv7_attn_pytorch_chunk_with_w_compile_break( + r[:,sta:end],w[:,sta:end],k[:,sta:end],v[:,sta:end], + kk[:,sta:end],a[:,sta:end], + BATCH_SIZE, chunk_remainder, N_HEAD, HEAD_SIZE, + # xx[:,sta:end], wkv_state_out, + torch.zeros(B,chunk_remainder,HC, dtype=xx.dtype, device=xx.device), wkv_state_out, + # offset=0, chunk_size=chunk_remainder + ) + # xlist.append(xpart) + + # # Concatenate the list + # xx = torch_cat_no_compiler(xlist, dim=1) + + # Return the output + return xx, wkv_state_out.to(dtype=wkv_state_in.dtype) + +#################################################################################################### +# Working reference copy, that has been validated to be "identical" to the reference implementation +# However this has known pytorch compilation issues, hence the modified chunk wise version is used +# instead for an approximate 5x speed up +#################################################################################################### +@torch.compiler.disable() +def rwkv7_attn_pytorch_ref( + r,w,k,v, kk,a, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + ######## pure pytorch method + # See: 
https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 + ######## + vk_state = wkv_state_in.float() + for t in range(SEQ_LEN): + r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] + vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) + xx[:,t] = ((vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) + wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) + return xx, wkv_state_out +#################################################################################################### + +#################################################################################################### +# Modified reference computation done in fp32, +# with changes made to bring the result closer to the cuda kernel +#################################################################################################### +@torch.compiler.disable() +def rwkv7_attn_pytorch_ref_fp32( + r,w,k,v, kk, iclr, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + ######## pure pytorch method (modified for fp32) + # See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 + ######## + w = (-w.float().exp()).exp() + + # wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device) + vk_state = wkv_state_in.float() + + a = -kk + b = kk * iclr + + for t in range(SEQ_LEN): + r_, w_, k_, v_, a_, b_= r[:,t].float(), w[:,t].float(), k[:,t].float(), v[:,t].float(), a[:,t].float(), b[:,t].float() + # ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ a_.float().view(BATCH_SIZE, N_HEAD,HEAD_SIZE,1) @ b_.view(BATCH_SIZE, N_HEAD,1,HEAD_SIZE) + vk.float()) + xx[:,t] = ((vk_state @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)).to(dtype=xx.dtype) + + wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) + return xx, wkv_state_out +#################################################################################################### + +def rwkv7_attn_pytorch_chunk( + r,w,k,v, kk,a, + BATCH_SIZE, N_HEAD, HEAD_SIZE, + xx, wkv_state_in, + offset=0, chunk_size=16 +): + ''' + Chunked version of the RWKV7 attention, for better performance. 
+ If the chunk size is less then 128, this is generally compilable + + This is used by the triton/cuda implement, for the remaining % 16 chunks + ''' + ######## pure pytorch method + # See: https://github.com/BlinkDL/RWKV-LM/blob/d4c42b2cac10f8f3896ce153e2310dc763662b7a/RWKV-v7/rwkv_v7_demo_fast.py#L238 + ######## + vk_state = wkv_state_in.float() + for i in range(chunk_size): + t = offset + i + r_, w_, k_, v_, kk_, a_ = r[:,t], w[:,t], k[:,t], v[:,t], kk[:,t], a[:,t] + vk = v_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ k_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + ab = (-kk_).view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1) @ (kk_*a_).view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE) + vk_state = (vk_state * w_.view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + vk_state @ ab.float() + vk.float()) + xx[:,t] = (vk_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE) + wkv_state_out = vk_state.to(dtype=wkv_state_in.dtype) + return xx, wkv_state_out + + +def rwkv7_attn_pytorch_v2_chunk_w_compile_break( + r,w,k,v, kk,a, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + ''' + Chunked version of the RWKV7 attention, for better performance + ''' + full_vk_ = v.view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ k.view(BATCH_SIZE,SEQ_LEN,N_HEAD, 1,HEAD_SIZE) + full_iclr_ = (kk * a).view(BATCH_SIZE,SEQ_LEN,N_HEAD,1,HEAD_SIZE) + full_ab = (-kk).view(BATCH_SIZE,SEQ_LEN,N_HEAD, HEAD_SIZE,1) @ full_iclr_ + + wkv_xx = torch.empty(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype, device=xx.device) + wkv_xx, wkv_state_out = rwkv7_attn_pytorch_v2_inner_w_compile_break( + r,w, + full_vk_, full_ab, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + wkv_xx, wkv_state_in + # xx, wkv_state_in + ) + + # if BATCH_SIZE != 1: + # print("BATCH_SIZE != 1 : ", BATCH_SIZE) + # if SEQ_LEN != 256: + # print("SEQ_LEN != 256 : ", SEQ_LEN) + + # xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) + xx[:] = (wkv_xx.to(dtype=xx.dtype) @ r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,SEQ_LEN,N_HEAD*HEAD_SIZE) + + return xx, wkv_state_out + +@torch.compiler.disable() +def rwkv7_attn_pytorch_v2_inner_w_compile_break( + r, w, + full_vk_, full_ab, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + ''' + Isolated sub-function with no compilation + ''' + return rwkv7_attn_pytorch_v2_inner_jit( + r, w, + full_vk_, full_ab, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in + ) + +# @torch.compile(fullgraph=True) +@torch.jit.script +def rwkv7_attn_pytorch_v2_inner_jit( + r, w, + full_vk_, full_ab, + BATCH_SIZE:int, SEQ_LEN:int, N_HEAD:int, HEAD_SIZE:int, + wkv_xx, wkv_state_in +): + ''' + Isolated sub-function with JIT + ''' + # wkv_xx = torch.zeros(BATCH_SIZE,SEQ_LEN,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=xx.dtype,device=xx.device) + # wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=w.device) + wkv_state = wkv_state_in + for t in range(SEQ_LEN): + # r_ = r[:,t] + # w_ = w[:,t] + # vk = full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE) + # ab = full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE) + + wkv_state = (wkv_state * w[:,t].view(BATCH_SIZE,N_HEAD,1,HEAD_SIZE).float() + wkv_state @ full_ab[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float() + full_vk_[:,t].view(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE).float()).clone() + wkv_xx[:,t] = wkv_state.to(dtype=w.dtype) + return wkv_xx, wkv_state + # xx[:,t] = ((wkv_state.to(dtype=xx.dtype) @ 
r_.view(BATCH_SIZE,N_HEAD,HEAD_SIZE,1)).view(BATCH_SIZE,N_HEAD*HEAD_SIZE)) + # return xx, wkv_state + + + # ---------------- + # block/kernel/rwkv7_attn_cuda.py + # ---------------- + import torch, os, time + # from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk + + #################################################################################################### + # Stateless reference implementation + #################################################################################################### + + def load_ref_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): + from torch.utils.cpp_extension import load + + # load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}" + load_name = "wind_backstepping" + load_file = "wkv7" + + # Check if the load_name is already loaded + if load_name in torch.ops: + return torch.ops.wind_backstepping + + # Log a warning on usage of the reference implementation + print("[WARNING] Reference CUDA kernel does not support input RWKV state, and is used only for training/validation purposes") + + # Get this script's file path, to compute the cuda path + this_file_path = os.path.dirname(os.path.abspath(__file__)) + + # # Get the device compute capability + # cuda_device = torch.cuda.current_device() + # compute_capability = torch.cuda.get_device_capability(cuda_device) + # compute_capability_str = f"{compute_capability[0]}{compute_capability[1]}" + # print("[INFO] Using compute capability:", compute_capability_str) + + # Load the kernel. There is some weird edge condition in compilation, + # where catching the exception and trying again sometimes works + flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] # + try: + load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) + except Exception as e: + print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has a weird race condition)...") + time.sleep(2) # Somehow this works, with a minor compilation error that passes on subsequent reruns + load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) + + # Return the loaded kernel + return torch.ops.wind_backstepping + + @torch.compiler.disable() + def ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa): + torch.ops.wind_backstepping.forward(w,q,k,v,z,b, y,s,sa) + + @torch.compiler.disable() + def ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): + torch.ops.wind_backstepping.backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) + + class RefCudaWindBackstepping(torch.autograd.Function): + @staticmethod + def forward(ctx, w,q,k,v,z,b): + CHUNK_LEN=16 + B,T,H,C = w.shape + assert T%CHUNK_LEN == 0 + assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) + assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) + y = torch.empty_like(v) + s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) + sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) + ref_wkv_cuda_forward(w,q,k,v,z,b, y,s,sa) + ctx.save_for_backward(w,q,k,v,z,b,s,sa) + return y + @staticmethod + def backward(ctx, dy): + assert all(i.dtype==torch.bfloat16 for i in [dy]) + assert all(i.is_contiguous() for i in [dy]) + w,q,k,v,z,b,s,sa = ctx.saved_tensors + dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in 
[w,q,k,v,z,b]] + ref_wkv_cuda_backward(w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) + return dw,dq,dk,dv,dz,db + +@torch.compiler.disable() + def rwkv7_attn_cuda_ref(q,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): + # Preload the kernel + load_ref_wkv_cuda_kernel() + + # Get the shape + B,T,HC = w.shape + C = HEAD_SIZE + H = HC//C + + # Assert that the chunk is a multiple of 16 + assert T % 16 == 0, 'reference cuda, only works in multiple of 16' + + # Initialize the state, if not provided - for compatibility (THE STATE IS NOT UPDATED) + s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 + + # Handling the cuda kernel + q,w,k,v,a,b = [i.view(B,T,H,C) for i in [q,w,k,v,(-kk),(kk*iclr)]] + + # Forward with backprop + xx = RefCudaWindBackstepping.apply(w,q,k,v,a,b) + return xx.view(B,T,HC), s0.view(B,H,C,C) + +#################################################################################################### +# State-based cuda code +#################################################################################################### + +def load_wkv_cuda_kernel(CHUNK_LEN = 16, HEAD_SIZE = 64): + from torch.utils.cpp_extension import load + + # load_name = f"wind_backstepping_C{HEAD_SIZE}_L{CHUNK_LEN}" + load_name = "state_wind_backstepping" + load_file = "state_wkv7" + + # Check if the load_name is already loaded + if load_name in torch.ops: + return torch.ops.state_wind_backstepping + + # Get this script's file path, to compute the cuda path + this_file_path = os.path.dirname(os.path.abspath(__file__)) + + # Load the kernel. There is some weird edge condition in compilation, + # where catching the exception and trying again sometimes works + flags = ['-res-usage', f'-D_C_={HEAD_SIZE}', f"-D_CHUNK_LEN_={CHUNK_LEN}", "--use_fast_math", "-O3", "-Xptxas -O3", "--extra-device-vectorization"] # + try: + load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) + except Exception as e: + print("[WARNING] Failed to load the kernel, trying again (sometimes the compiler has a weird race condition)...") + time.sleep(2) # Somehow this works, with a minor compilation error that passes on subsequent reruns + load(name=load_name, sources=[f'{this_file_path}/cuda/{load_file}_cuda.cu', f'{this_file_path}/cuda/{load_file}_op.cpp'], is_python_module=False, verbose=True, extra_cuda_cflags=flags) + + # Return the loaded kernel + return torch.ops.state_wind_backstepping + +@torch.compiler.disable() +def wkv_cuda_forward(state, w,q,k,v,z,b, y,s,sa): + torch.ops.state_wind_backstepping.forward(state, w,q,k,v,z,b, y,s,sa) + +@torch.compiler.disable() +def wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db): + torch.ops.state_wind_backstepping.backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) + +class CudaWindBackstepping(torch.autograd.Function): + @staticmethod + def forward(ctx, s0, w,q,k,v,z,b): + CHUNK_LEN=16 + B,T,H,C = w.shape + assert T%CHUNK_LEN == 0 + assert all(i.dtype==torch.bfloat16 for i in [w,q,k,v,z,b]) + assert all(i.is_contiguous() for i in [w,q,k,v,z,b]) + y = torch.empty_like(v) + s = torch.empty(B,H,T//CHUNK_LEN,C,C, dtype=torch.float32,device=w.device) + sa = torch.empty(B,T,H,C, dtype=torch.float32,device=w.device) + sOri = s0.clone() + wkv_cuda_forward(s0, w,q,k,v,z,b, y,s,sa) + ctx.save_for_backward(sOri, w,q,k,v,z,b,s,sa) + return y + @staticmethod + def backward(ctx, dy): + assert all(i.dtype==torch.bfloat16 for i in [dy]) + assert 
all(i.is_contiguous() for i in [dy]) + state,w,q,k,v,z,b,s,sa = ctx.saved_tensors + dS0,dw,dq,dk,dv,dz,db = [torch.empty_like(x) for x in [state,w,q,k,v,z,b]] + wkv_cuda_backward(state, w,q,k,v,z,b, dy,s,sa, dw,dq,dk,dv,dz,db) + return dS0,dw,dq,dk,dv,dz,db + +@torch.compiler.disable() +def rwkv7_attn_cuda(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): + # Preload the kernel + load_wkv_cuda_kernel() + + # Get the shape + B,T,HC = w.shape + + # Check if the chunk is multiple of 16 + chunk_remainder = T % 16 + + # Initialize the state + C = HEAD_SIZE + H = HC//C + + # Initialize the state + s0 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) if s0 is None else s0 + sT = s0.to(dtype=torch.float) + + # Optimize the call, if chunk is multiple of 16 + if chunk_remainder == 0: + chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, sT) + return chunk_xx, chunk_sT.to(dtype=s0.dtype) + + # Compute the number of chunks + chunks = T // 16 + si = chunks * 16 + + # Get the chunked output + chunk_xx, chunk_sT = rwkv7_attn_cuda_chunk( + r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], + HEAD_SIZE, s0 + ) + + # Get the remainder + remain_xx, last_sT = rwkv7_attn_pytorch_chunk( + r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], + B, H, C, + torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), + chunk_sT, chunk_size=chunk_remainder + ) + + # Concatenate and return results + return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) + + +def rwkv7_attn_cuda_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, s0=None): + ''' + Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) + ''' + B,T,HC = w.shape + assert T % 16 == 0, 'pure cuda, only works in multiple of 16' + C = HEAD_SIZE + H = HC//C + + # Handling the cuda kernel + a,b = -kk, (kk*iclr) + r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,a,b]] + + if s0 is None: + s1 = torch.zeros(B,H,C,C, dtype=torch.float,device=w.device) + else: + s1 = s0.clone() + + # Forward with backprop + xx = CudaWindBackstepping.apply(s1,w,r,k,v,a,b) + return xx.view(B,T,HC), s1.view(B,H,C,C) + + +# ---------------- +# block/kernel/rwkv7_attn_fla.py +# ---------------- +def rwkv7_attn_fla( + r,w,k,v, kk,iclr, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + from fla.ops.rwkv7.chunk import chunk_rwkv7 + + # Preprocessing the FLA + r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] + log_w = -w.float().exp() + + # Run the FLA + output, vk_state = chunk_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) + return output, vk_state.to(dtype=wkv_state_in.dtype) + +def rwkv7_attn_fused_reccurent_fla( + r,w,k,v, kk,iclr, + BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, + xx, wkv_state_in +): + from fla.ops.rwkv7.fused_recurrent import fused_recurrent_rwkv7 + + # Preprocessing the FLA + r,w,k,v,a,b = [i.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1) for i in [r,w,k,v,-kk,(kk*iclr)]] + log_w = -w.float().exp() + + # Run the FLA + output, vk_state = fused_recurrent_rwkv7(r=r, log_w=log_w, k=k, v=v, a=a, b=b, initial_state=wkv_state_in.float(), output_final_state=True) + return output, vk_state.to(dtype=wkv_state_in.dtype) + +# ---------------- +# block/kernel/rwkv7_attn_triton.py +# ---------------- +import torch +import torch as th +import triton +import triton.language as tl + 
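# ----------------
# Usage sketch (illustrative only, not part of the generated kernels): shows the
# tensor shapes the pure-pytorch kernel entry point above expects. The B/T/H/C
# values and the random inputs are assumptions for demonstration, and the decay
# `w` is passed pre-transformed as exp(-exp(w)), matching the remainder handling
# used by the cuda/triton wrappers in this file.
# ----------------
def _rwkv7_attn_pytorch_usage_sketch():
    B, T, H, C = 1, 40, 2, 64              # batch, seq len, heads, head size (assumed)
    HC = H * C
    dev = 'cuda' if torch.cuda.is_available() else 'cpu'

    r, k, v, iclr = (torch.randn(B, T, HC, device=dev) * 0.1 for _ in range(4))
    # kk is expected to be a per-head L2-normalised key variant (assumed here)
    kk = torch.nn.functional.normalize(torch.randn(B, T, H, C, device=dev), dim=-1).view(B, T, HC)
    w = torch.exp(-torch.exp(torch.randn(B, T, HC, device=dev)))   # decay values in (0,1)

    xx = torch.zeros(B, T, HC, device=dev)                              # output buffer [B,T,H*C]
    wkv_state = torch.zeros(B, H, C, C, dtype=torch.float, device=dev)  # wkv state [B,H,C,C]

    out, wkv_state_out = rwkv7_attn_pytorch(
        r, w, k, v, kk, iclr,
        B, T, H, C,
        xx, wkv_state,
    )
    return out.shape, wkv_state_out.shape   # (B, T, H*C), (B, H, C, C)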
+#################################################################################################### +# Triton specific coding (aka mostly songlin & Johan Sokrates Wind stuff) +# +# Copyright (c) 2024, Johan Sokrates Wind, licensed under MIT +# https://github.com/johanwind/wind_rwkv/blob/main/LICENSE +#################################################################################################### + +# ------------------------- +# Triton "smallhead" and "bighead" common code +# ------------------------- + +@triton.jit +def IND3(a,b,c,nb,nc): + return (a*nb+b)*nc+c +@triton.jit +def IND4(a,b,c,d,nb,nc,nd): + return ((a*nb+b)*nc+c)*nd+d +@triton.jit +def IND5(a,b,c,d,e,nb,nc,nd,ne): + return (((a*nb+b)*nc+c)*nd+d)*ne+e + +@triton.jit +def _prod(a,b): return a*b + +# inv(I-A) where A is a strictly lower triangular nxn matrix +@triton.jit +def tri_minv(A, n:tl.constexpr, prec:tl.constexpr): + i = tl.arange(0,n) + prod = (i[None,:]==i[:,None]).to(tl.float32) + for j in range(n-1): + prod += tl_dot(prec, prod, (A*((i[None,:]==j)*(i[:,None]>i[None,:]))).trans()) + return prod.trans() + +@triton.jit +def tl_dot(prec:tl.constexpr, a, b): + if prec == 'fp32': + return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=False) + elif prec == 'tf32': + return tl.dot(a.to(tl.float32),b.trans().to(tl.float32).trans(), allow_tf32=True) + elif prec == 'bf16': + return tl.dot(a.to(tl.bfloat16),b.trans().to(tl.bfloat16).trans(), allow_tf32=True) + else: + tl.static_assert(False) + +# ------------------------- +# Triton "smallhead" code +# ------------------------- + +@triton.jit +def fw_attn_triton(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): + bi = tl.program_id(1) + hi = tl.program_id(0) + + i = tl.arange(0,C)[None,:] + state = tl.load(s0_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) + for t0 in range(T//dT): + t = t0*dT+tl.arange(0,dT)[:,None] + sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + + w = (-sw.exp()).exp() + fw = tl.reduce(w, 0, _prod, keep_dims=True) + incl_pref = tl.cumprod(w,axis=0) + non_incl_pref = incl_pref / w + inv_incl_pref = 1 / incl_pref + + wq = sq * incl_pref + wa = sa * non_incl_pref + kwi = sk * inv_incl_pref + bwi = sb * inv_incl_pref + + mask1 = (t > t.trans()) + ab = tl_dot(prec, wa, bwi.trans()) * mask1 + ak = tl_dot(prec, wa, kwi.trans()) * mask1 + + ab_inv = tri_minv(ab, dT, prec) + + ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) + u = tl_dot(prec, ab_inv, ab_u) + mask2 = (t >= t.trans()) + qk = tl_dot(prec, wq, kwi.trans()) * mask2 + qb = tl_dot(prec, wq, bwi.trans()) * mask2 + yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + tl_dot(prec, wq, state.trans()) + tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) + + tl.store(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C), state.to(tl.float32)) + state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) + tl.store(sT_+IND4(bi,hi,i.trans(),i, H,C,C), state.to(tl.bfloat16)) + +@triton.jit +def bw_attn_triton(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_, dw_,dq_,dk_,dv_,da_,db_,ds0_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr): + bi = 
tl.program_id(1) + hi = tl.program_id(0) + + i = tl.arange(0,C)[None,:] + dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),i, H,C,C)).to(tl.float32) + + for t0 in range(T//dT-1,-1,-1): + t = t0*dT+tl.arange(0,dT)[:,None] + + state = tl.load(s_+IND5(bi,hi,t0,i.trans(),i, H,T//dT,C,C)).to(tl.float32) + + sw = tl.load(w_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sq = tl.load(q_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sk = tl.load(k_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sa = tl.load(a_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + + dw_fac = -sw.exp() + w = dw_fac.exp() + fw = tl.reduce(w, 0, _prod, keep_dims=True) + incl_pref = tl.cumprod(w,axis=0) + non_incl_pref = incl_pref / w + inv_incl_pref = 1 / incl_pref + + wq = sq * incl_pref + wa = sa * non_incl_pref + kwi = sk * inv_incl_pref + bwi = sb * inv_incl_pref + + mask1 = (t > t.trans()) + ab = tl_dot(prec, wa, bwi.trans()) * mask1 + ak = tl_dot(prec, wa, kwi.trans()) * mask1 + + ab_inv = tri_minv(ab, dT, prec) + + ab_u = tl_dot(prec, ak, sv) + tl_dot(prec, wa, state.trans()) + u = tl_dot(prec, ab_inv, ab_u) + mask2 = (t >= t.trans()) + qk = tl_dot(prec, wq, kwi.trans()) * mask2 + qb = tl_dot(prec, wq, bwi.trans()) * mask2 + + du = tl_dot(prec, qb.trans(), sdy) + tl_dot(prec, bwi*fw, dstate.trans()) + dab_u = tl_dot(prec, ab_inv.trans(), du) + + dv = tl_dot(prec, qk.trans(), sdy) + tl_dot(prec, kwi*fw, dstate.trans()) + tl_dot(prec, ak.trans(), dab_u) + tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) + + dab = tl_dot(prec, tl_dot(prec, ab_inv.trans(), du), u.trans()) * mask1 + dak = tl_dot(prec, dab_u, sv.trans()) * mask1 + dab_u_state = tl_dot(prec, dab_u, state) + da = non_incl_pref * (tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) + tl.store(da_+IND4(bi,t,hi,i, T,H,C), da.to(tl.bfloat16)) + + dqb = tl_dot(prec, sdy, u.trans()) * mask2 + dqk = tl_dot(prec, sdy, sv.trans()) * mask2 + dy_state = tl_dot(prec, sdy, state) + dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) + tl.store(dq_+IND4(bi,t,hi,i, T,H,C), dq.to(tl.bfloat16)) + + fw_u_dstate = fw * tl_dot(prec, u, dstate) + db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) + tl.store(db_+IND4(bi,t,hi,i, T,H,C), db.to(tl.bfloat16)) + + fw_v_dstate = fw * tl_dot(prec, sv, dstate) + dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) + tl.store(dk_+IND4(bi,t,hi,i, T,H,C), dk.to(tl.bfloat16)) + + dw0 = fw * tl.sum(state*dstate, axis=0,keep_dims=True) + for k in range(t0*dT,t0*dT+dT): + lmask = (tk) + A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) + A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (tk) + dy_state*wq * (t>=k) + dw = tl.sum(A, axis=0,keep_dims=True) + dw0 + + wk = tl.load(w_+IND4(bi,k,hi,i, T,H,C)).to(tl.float32) + dw *= -wk.exp() + tl.store(dw_+IND4(bi,k,hi,i, T,H,C), dw.to(tl.bfloat16)) + + dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) + tl.store(ds0_+IND4(bi,hi,i.trans(),i, H,C,C), dstate.to(tl.bfloat16)) + +class TritonRWKV7(th.autograd.Function): + @staticmethod + def forward(ctx, w,q,k,v,z,b,s0, dot_prec): + K = 16 + B,T,H,C = w.shape + s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 + y = th.empty_like(v) + sT = th.empty_like(s0) + s = th.zeros(B,H,T//K,C,C, 
dtype=th.float32,device=w.device) + fw_attn_triton[(H,B)](w,q,k,v,z,b, s0,y,s,sT, B,T,H,C,K, dot_prec) + ctx.dot_prec = dot_prec + ctx.save_for_backward(w,q,k,v,z,b,s) + return y, sT + @staticmethod + def backward(ctx, dy, dsT): + K = 16 + w,q,k,v,z,b,s = ctx.saved_tensors + B,T,H,C = w.shape + dw,dq,dk,dv,dz,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,z,b,dsT]] + bw_attn_triton[(H,B)](w,q,k,v,z,b, dy,s,dsT, dw,dq,dk,dv,dz,db,ds0, B,T,H,C,K, ctx.dot_prec) + return dw,dq,dk,dv,dz,db,ds0,None + +# ------------------------- +# Triton "bighead" code +# ------------------------- + +@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) +@triton.jit +def fw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, s0_,y_,s_,sT_, wq_,wa_,kwi_,bwi_,fw_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): + tl.static_assert(C%dC == 0) + bi = tl.program_id(1) + hi = tl.program_id(0) + for i0 in range(0,C,dC): + i = i0+tl.arange(0,dC)[None,:] + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + state = tl.load(s0_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) + tl.store(s_+IND5(bi,hi,0,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) + + for t0 in range(T//dT): + dt = tl.arange(0,dT)[:,None] + t = t0*dT+dt + tl.debug_barrier() + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + + w = (-sw.exp()).exp() + fw = tl.reduce(w, 0, _prod, keep_dims=True) + incl_pref = tl.cumprod(w,axis=0) + non_incl_pref = incl_pref / w + inv_incl_pref = 1 / incl_pref + + wq = sq * incl_pref + wa = sa * non_incl_pref + kwi = sk * inv_incl_pref + bwi = sb * inv_incl_pref + + tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) + tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) + tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) + tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) + tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) + tl.debug_barrier() + + ab = tl.zeros((dT,dT), tl.float32) + ak = tl.zeros((dT,dT), tl.float32) + qb = tl.zeros((dT,dT), tl.float32) + qk = tl.zeros((dT,dT), tl.float32) + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + + wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + + sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + + ab += tl_dot(prec, wa, bwi.trans()) + ak += tl_dot(prec, wa, kwi.trans()) + qb += tl_dot(prec, wq, bwi.trans()) + qk += tl_dot(prec, wq, kwi.trans()) + + mask1 = (t > t.trans()) + mask2 = (t >= t.trans()) + ab *= mask1 + ak *= mask1 + qb *= mask2 + qk *= mask2 + + ab_inv = tri_minv(ab, dT, prec) + + for i0 in range(0,C,dC): + i = i0+tl.arange(0,dC)[None,:] + sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + + wa_state = tl.zeros((dT,dC), tl.float32) + wq_state = tl.zeros((dT,dC), tl.float32) + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) + wa = tl.load(wa_+IND4(bi,hi,dt,j, 
H,dT,C)).to(tl.float32) + wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + wa_state += tl_dot(prec, wa, state.trans()) + wq_state += tl_dot(prec, wq, state.trans()) + + ab_u = tl_dot(prec, ak, sv) + wa_state + u = tl_dot(prec, ab_inv, ab_u) + yy = tl_dot(prec, qk, sv) + tl_dot(prec, qb, u) + wq_state + tl.store(y_+IND4(bi,t,hi,i, T,H,C), yy.to(tl.bfloat16)) + + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) + kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + fw = tl.load(fw_+IND3(bi,hi,j, H,C)) + + state = state * fw + tl_dot(prec, sv.trans(), kwi*fw) + tl_dot(prec, u.trans(), bwi*fw) + + if t0+1 < T//dT: + tl.store(s_+IND5(bi,hi,t0+1,i.trans(),j, H,T//dT,C,C), state.to(tl.float32)) + else: + tl.store(sT_+IND4(bi,hi,i.trans(),j, H,C,C), state.to(tl.bfloat16)) + + +@triton.autotune(configs=[triton.Config({'dC': dC}, num_stages=1) for dC in [16,32,64]], key=['T','H','C','dT','prec']) +@triton.jit +def bw_attn_triton_bighead(w_,q_,k_,v_,a_,b_, dy_,s_,dsT_,ds_, dw_,dq_,dk_,dv_,da_,db_,ds0_, wq_,wa_,kwi_,bwi_,fw_,u_,dab_u_, B:tl.constexpr,T:tl.constexpr,H:tl.constexpr,C:tl.constexpr,dT:tl.constexpr, prec:tl.constexpr, dC:tl.constexpr): + tl.static_assert(C%dC == 0) + bi = tl.program_id(1) + hi = tl.program_id(0) + + for i0 in range(0,C,dC): + i = i0+tl.arange(0,dC)[None,:] + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + dstate = tl.load(dsT_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) + tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) + + for t0 in range(T//dT-1,-1,-1): + dt = tl.arange(0,dT)[:,None] + t = t0*dT+dt + tl.debug_barrier() + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sq = tl.load(q_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sk = tl.load(k_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + + w = (-sw.exp()).exp() + fw = tl.reduce(w, 0, _prod, keep_dims=True) + incl_pref = tl.cumprod(w,axis=0) + non_incl_pref = incl_pref / w + inv_incl_pref = 1 / incl_pref + + wq = sq * incl_pref + wa = sa * non_incl_pref + kwi = sk * inv_incl_pref + bwi = sb * inv_incl_pref + + tl.store(wq_+IND4(bi,hi,dt,j, H,dT,C), wq.to(tl.float32)) + tl.store(wa_+IND4(bi,hi,dt,j, H,dT,C), wa.to(tl.float32)) + tl.store(kwi_+IND4(bi,hi,dt,j, H,dT,C), kwi.to(tl.float32)) + tl.store(bwi_+IND4(bi,hi,dt,j, H,dT,C), bwi.to(tl.float32)) + tl.store(fw_+IND3(bi,hi,j, H,C), fw.to(tl.float32)) + tl.debug_barrier() + + ab = tl.zeros((dT,dT), tl.float32) + ak = tl.zeros((dT,dT), tl.float32) + qb = tl.zeros((dT,dT), tl.float32) + qk = tl.zeros((dT,dT), tl.float32) + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + + wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + + sa = tl.load(a_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + sb = tl.load(b_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + + ab += tl_dot(prec, wa, bwi.trans()) + ak += tl_dot(prec, wa, kwi.trans()) + qb += tl_dot(prec, wq, bwi.trans()) + qk += tl_dot(prec, wq, kwi.trans()) + + mask1 = (t > t.trans()) + mask2 = (t >= t.trans()) + ab *= mask1 + ak *= mask1 + qb *= mask2 + qk *= mask2 + + ab_inv 
= tri_minv(ab, dT, prec) + + dab = tl.zeros((dT,dT), tl.float32) + dak = tl.zeros((dT,dT), tl.float32) + dqb = tl.zeros((dT,dT), tl.float32) + dqk = tl.zeros((dT,dT), tl.float32) + + tl.debug_barrier() + for i0 in range(0,C,dC): + i = i0+tl.arange(0,dC)[None,:] + wa_state = tl.zeros((dT,dC), tl.float32) + bwi_dw_dstate = tl.zeros((dT,dC), tl.float32) + kwi_dw_dstate = tl.zeros((dT,dC), tl.float32) + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) + dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) + wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + fw = tl.load(fw_+IND3(bi,hi,j, H,C)) + + wa_state += tl_dot(prec, wa, state.trans()) + bwi_dw_dstate += tl_dot(prec, bwi*fw, dstate.trans()) + kwi_dw_dstate += tl_dot(prec, kwi*fw, dstate.trans()) + + sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + + ab_u = tl_dot(prec, ak, sv) + wa_state + u = tl_dot(prec, ab_inv, ab_u) + du = tl_dot(prec, qb.trans(), sdy) + bwi_dw_dstate + dab_u = tl_dot(prec, ab_inv.trans(), du) + + tl.store(u_+IND4(bi,hi,dt,i, H,dT,C), u.to(tl.float32)) + tl.store(dab_u_+IND4(bi,hi,dt,i, H,dT,C), dab_u.to(tl.float32)) + + dv = tl_dot(prec, qk.trans(), sdy) + kwi_dw_dstate + tl_dot(prec, ak.trans(), dab_u) + tl.store(dv_+IND4(bi,t,hi,i, T,H,C), dv.to(tl.bfloat16)) + + dab += tl_dot(prec, dab_u, u.trans()) * mask1 + dak += tl_dot(prec, dab_u, sv.trans()) * mask1 + dqb += tl_dot(prec, sdy, u.trans()) * mask2 + dqk += tl_dot(prec, sdy, sv.trans()) * mask2 + tl.debug_barrier() + + for j0 in range(0,C,dC): + j = j0+tl.arange(0,dC)[None,:] + + dy_state = tl.zeros((dT,dC), tl.float32) + dab_u_state = tl.zeros((dT,dC), tl.float32) + fw_u_dstate = tl.zeros((dT,dC), tl.float32) + fw_v_dstate = tl.zeros((dT,dC), tl.float32) + state_dstate = tl.zeros((1,dC), tl.float32) + + fw = tl.load(fw_+IND3(bi,hi,j, H,C)) + wa = tl.load(wa_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + wq = tl.load(wq_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + for i0 in range(0,C,dC): + i = i0+tl.arange(0,dC)[None,:] + + u = tl.load(u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) + dab_u = tl.load(dab_u_+IND4(bi,hi,dt,i, H,dT,C)).to(tl.float32) + sv = tl.load(v_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + sdy = tl.load(dy_+IND4(bi,t,hi,i, T,H,C)).to(tl.float32) + + state = tl.load(s_+IND5(bi,hi,t0,i.trans(),j, H,T//dT,C,C)).to(tl.float32) + tl.debug_barrier() + dstate = tl.load(ds_+IND4(bi,hi,i.trans(),j, H,C,C)).to(tl.float32) + tl.debug_barrier() + + dab_u_state += tl_dot(prec, dab_u, state) + fw_u_dstate += fw * tl_dot(prec, u, dstate) + fw_v_dstate += fw * tl_dot(prec, sv, dstate) + dy_state += tl_dot(prec, sdy, state) + + state_dstate += tl.sum(state*dstate, axis=0,keep_dims=True) + + dstate = dstate * fw + tl_dot(prec, sdy.trans(), wq) + tl_dot(prec, dab_u.trans(), wa) + if t0 > 0: + tl.store(ds_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.float32)) + else: + tl.store(ds0_+IND4(bi,hi,i.trans(),j, H,C,C), dstate.to(tl.bfloat16)) + + sw = tl.load(w_+IND4(bi,t,hi,j, T,H,C)).to(tl.float32) + w = (-sw.exp()).exp() + incl_pref = tl.cumprod(w,axis=0) + non_incl_pref = incl_pref / w + inv_incl_pref = 1 / incl_pref + + bwi = tl.load(bwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + kwi = tl.load(kwi_+IND4(bi,hi,dt,j, H,dT,C)).to(tl.float32) + + da = non_incl_pref * 
(tl_dot(prec, dab, bwi) + tl_dot(prec, dak, kwi) + dab_u_state) + tl.store(da_+IND4(bi,t,hi,j, T,H,C), da.to(tl.bfloat16)) + + dq = incl_pref * (tl_dot(prec, dqb, bwi) + tl_dot(prec, dqk, kwi) + dy_state) + tl.store(dq_+IND4(bi,t,hi,j, T,H,C), dq.to(tl.bfloat16)) + + db = inv_incl_pref * (tl_dot(prec, dab.trans(), wa) + tl_dot(prec, dqb.trans(), wq) + fw_u_dstate) + tl.store(db_+IND4(bi,t,hi,j, T,H,C), db.to(tl.bfloat16)) + + dk = inv_incl_pref * (tl_dot(prec, dak.trans(), wa) + tl_dot(prec, dqk.trans(), wq) + fw_v_dstate) + tl.store(dk_+IND4(bi,t,hi,j, T,H,C), dk.to(tl.bfloat16)) + + dw0 = fw * state_dstate + for k in range(t0*dT,t0*dT+dT): + lmask = (tk) + A += (tl_dot(prec, dqb*lmask, bwi) + tl_dot(prec, dqk*lmask, kwi)) * wq * (t>=k) + A += (fw_v_dstate*kwi + fw_u_dstate*bwi) * (tk) + dy_state*wq * (t>=k) + dw = tl.sum(A, axis=0,keep_dims=True) + dw0 + + wk = tl.load(w_+IND4(bi,k,hi,j, T,H,C)).to(tl.float32) + dw *= -wk.exp() + tl.store(dw_+IND4(bi,k,hi,j, T,H,C), dw.to(tl.bfloat16)) + +class TritonBigheadRWKV7(th.autograd.Function): + @staticmethod + def forward(ctx, w,q,k,v,a,b,s0, dot_prec): + K = 16 + B,T,H,C = w.shape + assert T%K == 0 + assert C%16 == 0 + s0 = th.zeros(B,H,C,C, dtype=w.dtype,device=w.device) if s0 is None else s0 + y = th.empty_like(v) + sT = th.empty_like(s0) + s = th.zeros(B,H,T//K,C,C, dtype=th.float32,device=w.device) + wq,wa,kwi,bwi = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(4)] + fw = th.empty(B,H,C, dtype=th.float32,device=w.device) + fw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, s0,y,s,sT, wq,wa,kwi,bwi,fw, B,T,H,C,K, dot_prec) + ctx.dot_prec = dot_prec + ctx.save_for_backward(w,q,k,v,a,b,s) + return y, sT + @staticmethod + def backward(ctx, dy, dsT): + K = 16 + w,q,k,v,a,b,s = ctx.saved_tensors + B,T,H,C = w.shape + dw,dq,dk,dv,da,db,ds0 = [th.empty_like(x) for x in [w,q,k,v,a,b,dsT]] + fw = th.empty(B,H,C, dtype=th.float32,device=w.device) + ds = th.empty(B,H,C,C, dtype=th.float32,device=w.device) + wq,wa,kwi,bwi,u,dab_u = [th.empty(B,H,K,C, dtype=th.float32,device=w.device) for i in range(6)] + bw_attn_triton_bighead[(H,B)](w,q,k,v,a,b, dy,s,dsT,ds, dw,dq,dk,dv,da,db,ds0, wq,wa,kwi,bwi,fw,u,dab_u, B,T,H,C,K, ctx.dot_prec) + return dw,dq,dk,dv,da,db,ds0,None + +#################################################################################################### +# Start of pytorch code +#################################################################################################### + +# from .rwkv7_attn_pytorch import rwkv7_attn_pytorch_chunk + +# ------------------------- +# Pytorch "smallhead" code +# ------------------------- + +def rwkv7_attn_triton(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): + B,T,HC = w.shape + + # Check if the chunk is multiple of 16 + chunk_remainder = T % 16 + + # Optimize the call, if chunk is multiple of 16 + if chunk_remainder == 0: + return rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) + + # Initialize the state + C = HEAD_SIZE + H = HC//C + s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 + + # Compute the number of chunks + chunks = T // 16 + si = chunks * 16 + + # Get the chunked output + chunk_xx, chunk_sT = rwkv7_attn_triton_chunk( + r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], + HEAD_SIZE, dot_prec, s0 + ) + + # Get the remainder + remain_xx, last_sT = rwkv7_attn_pytorch_chunk( + r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], + B, H, C, torch.zeros(B, chunk_remainder, HC, 
device=w.device, dtype=w.dtype), + chunk_sT, chunk_size=chunk_remainder + ) + + # Concatenate and return results + return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) + + +def rwkv7_attn_triton_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): + ''' + Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) + ''' + B,T,HC = w.shape + assert T % 16 == 0, 'pure triton, only works in multiple of 16' + C = HEAD_SIZE + H = HC//C + + # Moving the triton specific operations into the chunk steps + r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] + s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 + xx, sT = TritonRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) + return xx.view(B,T,HC), sT + +# ------------------------- +# Pytorch "bighead" code +# ------------------------- + +def rwkv7_attn_triton_bighead(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): + B,T,HC = w.shape + + # Check if the chunk is multiple of 16 + chunk_remainder = T % 16 + + # Optimize the call, if chunk is multiple of 16 + if chunk_remainder == 0: + return rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE, dot_prec, s0) + + # Initialize the state + C = HEAD_SIZE + H = HC//C + s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 + + # Compute the number of chunks + chunks = T // 16 + si = chunks * 16 + + # Get the chunked output + chunk_xx, chunk_sT = rwkv7_attn_triton_bighead_chunk( + r[:,:si],w[:,:si],k[:,:si],v[:,:si], kk[:,:si],iclr[:,:si], + HEAD_SIZE, dot_prec, s0 + ) + + # Get the remainder + remain_xx, last_sT = rwkv7_attn_pytorch_chunk( + r[:,si:],torch.exp(-torch.exp(w[:,si:])),k[:,si:],v[:,si:], kk[:,si:],iclr[:,si:], + B, H, C, torch.zeros(B, chunk_remainder, HC, device=w.device, dtype=w.dtype), + chunk_sT, chunk_size=chunk_remainder + ) + + # Concatenate and return results + return torch.cat([chunk_xx.to(dtype=w.dtype), remain_xx.to(dtype=w.dtype)], dim=1), last_sT.to(dtype=s0.dtype) + + +def rwkv7_attn_triton_bighead_chunk(r,w,k,v, kk,iclr, HEAD_SIZE=64, dot_prec='fp32', s0=None): + ''' + Triton implementation running in blocks of 16 (hardcoded requirement for the kernel) + ''' + B,T,HC = w.shape + assert T % 16 == 0, 'pure triton, only works in multiple of 16' + C = HEAD_SIZE + H = HC//C + + # Moving the triton specific operations into the chunk steps + r,w,k,v,a,b = [i.view(B,T,H,C) for i in [r,w,k,v,-kk,(kk*iclr)]] + s0 = th.zeros(B,H,C,C, dtype=th.float,device=w.device) if s0 is None else s0 + xx, sT = TritonBigheadRWKV7.apply(w,r,k,v,a,b,s0,dot_prec) + return xx.view(B,T,HC), sT + + +# ---------------- +# block/rwkv7_block_config_map.py +# ---------------- +from dataclasses import dataclass +from typing import Optional +from typing import Union +import torch + +@dataclass +class RWKV7BlockConfigMap: + + """Configuration map for RWKV based models""" + # Key properties for the block / model + num_hidden_layers: int + hidden_size: int + + head_size: int = 64 + + # Dropout rate, should only be used in training + dropout_rate: float = 0.0 + + # Implementation backend to use + tmix_backend: str = "auto" + + # --- + # OPTIONAL PROPS + # + # Optional properties which can be derived + # or can be overwritten by the user + # --- + + # Current layer_id of the block + layer_id: Optional[int] = None + + # Device and Data type + device: Union[torch.device, str, None] = None + dtype: Union[torch.dtype, str, None] = None + + # Channel mix / FFN block 
dimension size + hidden_size_ffn: Optional[int] = None + hidden_size_att: Optional[int] = None + + # # number of heads + # n_head: Optional[int] = None + + # --- + # Initializer, with excess arg ignore + # --- + def __init__( + self, + num_hidden_layers: int, + hidden_size: int, + head_size: int = 64, + dropout_rate: float = 0.0, + tmix_backend: str = "auto", + layer_id: Optional[int] = None, + device: Union[torch.device, str, None] = None, + dtype: Union[torch.dtype, str, None] = None, + hidden_size_ffn: Optional[int] = None, + hidden_size_att: Optional[int] = None, + **kwargs + ) -> None: + ''' + Constructor for the config + ''' + self.num_hidden_layers = num_hidden_layers + self.hidden_size = hidden_size + self.head_size = head_size + self.dropout_rate = dropout_rate + self.tmix_backend = tmix_backend + self.layer_id = layer_id + self.device = device + self.dtype = dtype + self.hidden_size_ffn = hidden_size_ffn + self.hidden_size_att = hidden_size_att + + # --- + # OPTIONAL PROPS FETCHER + # --- + + def get_layer_id(self, fallback:int) -> int: + ''' + Returns the layer id + ''' + if self.layer_id is not None: + return self.layer_id + return fallback + + def get_device(self, fallback:str) -> torch.device: + ''' + Returns the device + ''' + if self.device is not None: + return torch.device(self.device) + if fallback is not None: + return torch.device(fallback) + return torch.get_default_device() + + def get_dtype(self, fallback:str) -> torch.dtype: + ''' + Returns the dtype + ''' + if self.dtype is not None: + key = self.dtype + else: + key = fallback + + # if dtype is already torch.dtype + if isinstance(key, torch.dtype): + return key + + # Get and Check if the dtype is instance of torch.dtype + ret = getattr(torch, key) + assert isinstance(ret, torch.dtype), f"Invalid dtype: {self.dtype}" + return ret + + # --- + + def get_hidden_size_att(self) -> int: + ''' + Returns the dimension of attention + ''' + if self.hidden_size_att is not None: + hidden_size_att = self.hidden_size_att + else: + hidden_size = self.hidden_size + assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" + hidden_size_att = hidden_size + assert hidden_size_att % 32 == 0, f"hidden_size_att must be divisible by 32 ({hidden_size_att})" + return hidden_size_att + + def get_hidden_size_ffn(self) -> int: + ''' + Returns the dimension of feed forward network + ''' + if self.hidden_size_ffn is not None: + hidden_size_ffn = self.hidden_size_ffn + else: + hidden_size = self.hidden_size + assert hidden_size % 32 == 0, f"hidden_size must be divisible by 32" + hidden_size_ffn = hidden_size * 4 + + assert hidden_size_ffn % 32 == 0, f"hidden_size_ffn must be divisible by 32" + return hidden_size_ffn + + # def get_n_head(self) -> int: + # ''' + # Returns the number of heads + # ''' + # if self.n_head is not None: + # n_head = self.n_head + # else: + # hidden_size_att = self.get_hidden_size_att() + # n_head = self.get_hidden_size_att() // self.head_size + # assert hidden_size_att % n_head == 0 , f"hidden_size_att must be divisible by head_size ({self.head_size})" + # + # return n_head + + # --- + # Duplicator & Normalizer + # --- + + def new_block_config_map(self, **kwargs) -> 'RWKV7BlockConfigMap': + ''' + Returns a new config map with updated values + ''' + + new_dict = {} + for key in RWKV7BlockConfigMap.__dataclass_fields__: + if key in self.__dict__: + new_dict[key] = self.__dict__[key] + new_dict.update(kwargs) + + return RWKV7BlockConfigMap(**new_dict) + + @staticmethod + def normalize(config_map: any) -> 
'RWKV7BlockConfigMap': + ''' + Converts either maps, objs or RWKV7BlockConfigMap + ''' + if isinstance(config_map, RWKV7BlockConfigMap): + return config_map + + dict_obj = None + if isinstance(config_map, dict): + dict_obj = config_map + elif hasattr(config_map, '__dict__'): + dict_obj = config_map.__dict__ + + if dict_obj is not None: + # Filter for only valeus in RWKV7BlockConfigMap + new_dict = {} + for key, value in dict_obj.items(): + if key in RWKV7BlockConfigMap.__dataclass_fields__: + new_dict[key] = value + return RWKV7BlockConfigMap(**new_dict) + + raise ValueError(f"Unsupported config_map type: {type(config_map)}") + + +# ---------------- +# block/rwkv7_channel_mix.py +# ---------------- +import torch +from torch import nn +from typing import Union +# from .rwkv7_block_config_map import RWKV7BlockConfigMap + +class RWKV7ChannelMix(torch.nn.Module): + ''' + ChannelMix block for RWKV + This is similar to transformer FFN block + ''' + + def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): + ''' + Initialize the ChannelMix block. + + Note: this does not initialize the parameter weights itself + which would depend on the `init_parameters()` method + ''' + + super().__init__() + + configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) + self.configMap = configMap + + # Get various props + hidden_size = configMap.hidden_size + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + + # By default, hidden_size_ffn = hidden_size * 4 + hidden_size_ffn = configMap.get_hidden_size_ffn() + + # Build the various params + # --- + self.x_k = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) + self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) + + def init_parameters(self): + ''' + Reset the parameters of the block, to an initial state used for training a model from scratch + ''' + + # Get required props + configMap = self.configMap + hidden_size = configMap.hidden_size + num_hidden_layers = configMap.num_hidden_layers + + # Get optional props + layer_id = configMap.get_layer_id(0) + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + + # By default, hidden_size_ffn = hidden_size * 4 + hidden_size_ffn = configMap.get_hidden_size_ffn() + + # Reset the various params + # --- + with torch.no_grad(): # fancy init of time_mix + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + ddd = torch.ones(1, 1, hidden_size) + for i in range(hidden_size): + ddd[0, 0, i] = i / hidden_size + self.x_k = nn.Parameter( (1.0 - torch.pow(ddd, ratio_1_to_almost0**4)).to(device, dtype=dtype) ) + + self.key = nn.Linear(hidden_size, hidden_size_ffn, bias=False, device=device, dtype=dtype) + self.value = nn.Linear(hidden_size_ffn, hidden_size, bias=False, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, last_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: + ''' + Forwarding channel mix given the input tokens and states. 
+ + Given: + - Incoming token embedding size of shape [batch_size, seq_len, embedding_size] + - Incoming channel mix, shift states of the various batches [batch_size, state_size] + + Returns a pair + - Output embedding of shape [batch_size, seq_len, embedding_size] + - Output channel mix, shift state of shape [batch_size, state_size] + ''' + # last_state = last_state.to(self.key.weight.device) + + ########## + ## x070 + ########## + + dxprev = torch.cat((last_state.unsqueeze(1), x[:, :-1]), dim=1) - x + xk = x + dxprev * self.x_k + k = torch.relu( self.key(xk) ) ** 2 + + return self.value(k), x[:,-1] + + @torch.compile(mode="default", fullgraph=True) + def forward_with_default_compile(self, in_x: torch.Tensor, in_state: torch.Tensor, out_x: torch.Tensor, out_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: + ''' + Compiled variant of the forward function + With no new tensors being created for the output + Useful for static memory allocation optimizations during inference + ''' + out_x[:], out_state[:] = self.forward_with_reduce_compile(in_x, in_state) + return out_x, out_state + + @torch.compile(mode="reduce-overhead", fullgraph=True) + def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: torch.Tensor) -> tuple[torch.Tensor,torch.Tensor]: + ''' + Compiled variant of the forward function + ''' + return self.forward(in_x, in_state) + + def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): + ''' + Given the full/partial RWKV model weights, loaded via `torch.load` + Set up the current module weights, using the layer_id + ''' + # Get the current state_dict + current_state_dict = self.state_dict() + + # Iterate over each parameter in the state_dict, and copy the matching values from the model + for n in current_state_dict: + model_key = f"blocks.{layer_id}.ffn.{n}" + if model_key not in model_state_dict: + continue + + # Copy the values from the model state_dict + try: + current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) + except Exception as e: + print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") + raise e + +# ---------------- +# block/rwkv7_time_mix.py +# ---------------- +import torch, math +from torch import nn +from torch import Tensor +from typing import Union +from torch.nn import functional as F + +# from .rwkv7_block_config_map import RWKV7BlockConfigMap + +# Check for the triton package, if a GPU is available +triton = None +if torch.cuda.is_available(): + try: + import triton + except ImportError: + triton = None +else: + print("[WARNING] Triton not available, falling back to pytorch mode by default - this is significantly slower") + +class RWKV7TimeMix(torch.nn.Module): + ''' + Time Mix block for RWKV V7 + ''' + + def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): + ''' + Initialize the TimeMix block. 
+ + Note: this does not initialize the parameter weights itself + which would depend on the `init_parameters()` method + ''' + super().__init__() + + configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) + self.configMap = configMap + + # Get required props + hidden_size = configMap.hidden_size + # num_hidden_layers = configMap.num_hidden_layers + + # Get the layer id + layer_id = configMap.get_layer_id(0) + self.layer_id = layer_id + + # Get optional props + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + + # By default, hidden_size_ffn = hidden_size + hidden_size_att = configMap.get_hidden_size_att() + + # Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att + assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" + + # Head size settings + head_size = configMap.head_size + self.head_size = head_size + + # Number of heads + n_head = hidden_size_att // head_size + assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" + self.n_head = n_head + + # Backend + self.tmix_backend = configMap.tmix_backend + + # Build the various params + # --- + + with torch.no_grad(): + # Note: for some data, you can reduce D_GATE_LORA or even remove this gate + def calc_lora_rank(exponent, multiplier): + return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 + D_DECAY_LORA = calc_lora_rank(0.5, 1.8) + D_AAA_LORA = calc_lora_rank(0.5, 1.8) + D_MV_LORA = calc_lora_rank(0.5, 1.3) + D_GATE_LORA = calc_lora_rank(0.8, 0.6) + + self.x_r = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.x_w = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.x_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.x_v = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.x_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.x_g = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + + self.w0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.w1 = nn.Parameter(torch.empty(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) + self.w2 = nn.Parameter(torch.empty(D_DECAY_LORA, hidden_size, device=device, dtype=dtype)) + + self.a0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.a1 = nn.Parameter(torch.empty(hidden_size,D_AAA_LORA, device=device, dtype=dtype)) + self.a2 = nn.Parameter(torch.empty(D_AAA_LORA,hidden_size, device=device, dtype=dtype)) + + if layer_id > 0: + self.v0 = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.v1 = nn.Parameter(torch.empty(hidden_size,D_MV_LORA, device=device, dtype=dtype)) + self.v2 = nn.Parameter(torch.empty(D_MV_LORA,hidden_size, device=device, dtype=dtype)) + + self.g1 = nn.Parameter(torch.empty(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) + self.g2 = nn.Parameter(torch.empty(D_GATE_LORA, hidden_size, device=device, dtype=dtype)) + + self.k_k = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.k_a = nn.Parameter(torch.empty(1,1,hidden_size, device=device, dtype=dtype)) + self.r_k = nn.Parameter(torch.empty(n_head, head_size, device=device, dtype=dtype)) + + self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.key = 
nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) + self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) + + def init_parameters(self): + ''' + Reset the parameters of the block, to an initial state used for training a model from scratch + ''' + configMap = self.configMap + + # Get required props + hidden_size = configMap.hidden_size + num_hidden_layers = configMap.num_hidden_layers + + # Get the layer id + layer_id = self.layer_id + + # Get optional props + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + + # By default, hidden_size_ffn = hidden_size + hidden_size_att = configMap.get_hidden_size_att() + + # Assert hidden_size == hidden_size_att, until we support different hidden_size and hidden_size_att + assert hidden_size == hidden_size_att, "hidden_size should be equal to hidden_size_att (@TODO: support different hidden_size and hidden_size_att)" + + # Head size settings + head_size = self.head_size + + # Number of heads + n_head = hidden_size_att // head_size + assert hidden_size_att % head_size == 0, "hidden_size_att should be divisible by head_size" + + # Reset the various params + # --- + with torch.no_grad(): + ratio_0_to_1 = layer_id / (num_hidden_layers - 1) # 0 to 1 + ratio_1_to_almost0 = 1.0 - (layer_id / num_hidden_layers) # 1 to ~0 + ddd = torch.ones(1, 1, hidden_size, device=device, dtype=dtype) + for i in range(hidden_size): + ddd[0, 0, i] = i / hidden_size + + # Note: for some data, you can reduce D_GATE_LORA or even remove this gate + def calc_lora_rank(exponent, multiplier): + return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32 + D_DECAY_LORA = calc_lora_rank(0.5, 1.8) + D_AAA_LORA = calc_lora_rank(0.5, 1.8) + D_MV_LORA = calc_lora_rank(0.5, 1.3) + D_GATE_LORA = calc_lora_rank(0.8, 0.6) + + self.x_r.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) + self.x_w.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) + self.x_k.copy_(1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1)) + self.x_v.copy_(1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1)) + self.x_a.copy_(1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0)) + self.x_g.copy_(1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0)) + + def ortho_init(x, scale): + x = x.to(device) + shape = x.shape + if len(shape) == 2: + gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1 + nn.init.orthogonal_(x, gain=gain * scale) + elif len(shape) == 3: + gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1 + for i in range(shape[0]): + nn.init.orthogonal_(x[i], gain=gain * scale) + else: + assert False + return x.to(device, dtype=dtype) + + # D_DECAY_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32)) + decay_speed = torch.ones(hidden_size, device=device, dtype=dtype) + for n in range(hidden_size): + decay_speed[n] = -7 + 5 * (n / (hidden_size - 1)) ** (0.85 + 1.0 * ratio_0_to_1 ** 0.5) + + self.w0.copy_(decay_speed.reshape(1,1,hidden_size).to(device, dtype=dtype) + 0.5) # !!! 0.5 comes from F.softplus !!! 
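+            # decay_speed above spans roughly [-7, -2] across the channels, and the +0.5
+            # offset compensates for the "- 0.5" in the soft-clamp applied in forward()
+            # (w = -F.softplus(-(self.w0 + w)) - 0.5), so the effective initial log-decay
+            # stays close to decay_speed itself.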
+ self.w1.copy_(torch.zeros(hidden_size, D_DECAY_LORA, device=device, dtype=dtype)) + self.w2.copy_(ortho_init(torch.zeros(D_DECAY_LORA, hidden_size), 0.1)) + + # D_AAA_LORA = max(32, int(round( (1.8*(hidden_size**0.5)) /32)*32)) # suggestion + self.a0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)) + self.a1.copy_(torch.zeros(hidden_size, D_AAA_LORA, device=device, dtype=dtype)) + self.a2.copy_(ortho_init(torch.zeros(D_AAA_LORA, hidden_size), 0.1)) + + # D_MV_LORA = max(32, int(round( (1.3*(hidden_size**0.5)) /32)*32)) # suggestion + if layer_id > 0: + self.v0.copy_(torch.zeros(1,1,hidden_size, device=device, dtype=dtype)+1.0) + self.v1.copy_(torch.zeros(hidden_size, D_MV_LORA, device=device, dtype=dtype)) + self.v2.copy_(ortho_init(torch.zeros(D_MV_LORA, hidden_size), 0.1)) + + # D_GATE_LORA = max(32, int(round( (0.6*(hidden_size**0.8)) /32)*32)) # suggestion + # Note: for some data, you can reduce D_GATE_LORA or even remove this gate + self.g1.copy_(torch.zeros(hidden_size, D_GATE_LORA, device=device, dtype=dtype)) + self.g2.copy_(ortho_init(torch.zeros(D_GATE_LORA, hidden_size), 0.1)) + + self.k_k.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)*0.85) + self.k_a.copy_(torch.ones(1,1,hidden_size, device=device, dtype=dtype)) + self.r_k.copy_(torch.zeros(n_head,head_size, device=device, dtype=dtype)) + + self.receptance = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.key = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.value = nn.Linear(hidden_size, hidden_size_att, bias=False, device=device, dtype=dtype) + self.output = nn.Linear(hidden_size_att, hidden_size, bias=False, device=device, dtype=dtype) + self.ln_x = nn.GroupNorm(n_head, hidden_size_att, device=device, dtype=dtype, eps=(1e-5)*head_size) + + def forward(self, x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: + ''' + forwarding time mix given the model weights and the input tokens and states. 
+
+        Given:
+        - Incoming token embeddings of shape [batch_size, seq_len, embedding_size]
+        - Incoming states, containing tensors of shapes:
+            [batch_size, state_size] ## Token Shift state,
+            [batch_size, n_head, head_size, head_size] ## WKV state
+        - Incoming v_first_val of shape [batch_size, seq_len, embedding_size]
+
+        Returns a tuple of:
+        - output embedding of shape [batch_size, seq_len, embedding_size]
+        - output states of shapes:
+            [batch_size, state_size] ## Token Shift state,
+            [batch_size, n_head, head_size, head_size] ## WKV state
+        - output v_first_val of shape [batch_size, seq_len, embedding_size]
+        '''
+        # Get the sizing
+        BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE = x.size()
+        N_HEAD = self.n_head
+        HEAD_SIZE = self.head_size
+
+        # Ensure wkv_state_in is initialized
+        if wkv_state_in is None:
+            wkv_state_in = torch.zeros(BATCH_SIZE,N_HEAD,HEAD_SIZE,HEAD_SIZE, dtype=torch.float,device=x.device)
+        else:
+            wkv_state_in = wkv_state_in.clone()
+
+        ##########
+        ## x070
+        ##########
+
+        shift_state_out = x[:, -1]
+        dxprev = torch.cat((shift_state_in.unsqueeze(1), x[:, :-1]), dim=1) - x
+
+        xr = x + dxprev * self.x_r
+        xw = x + dxprev * self.x_w
+        xk = x + dxprev * self.x_k
+        xv = x + dxprev * self.x_v
+        xa = x + dxprev * self.x_a
+        xg = x + dxprev * self.x_g
+        xx = dxprev
+
+        r = self.receptance(xr)
+        w = torch.tanh(xw @ self.w1) @ self.w2
+        k = self.key(xk)
+        v = self.value(xv)
+        g = torch.sigmoid(xg @ self.g1) @ self.g2
+        iclr = torch.sigmoid(self.a0 + (xa @ self.a1) @ self.a2) # iclr is the "in-context learning rate" (the "a" term in the reference kernels)
+
+        kk = F.normalize((k * self.k_k).view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1), dim=-1, p=2.0).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE)
+        k = k * (1 + (iclr-1) * self.k_a)
+
+        if self.layer_id == 0:
+            v_first_val = v # store the v of the first layer
+        else:
+            v = v + (v_first_val - v) * torch.sigmoid(self.v0 + (xv @ self.v1) @ self.v2) # add value residual
+
+        tmix_backend = self.tmix_backend.lower()
+        if tmix_backend == "auto":
+            if triton is None or self.receptance.weight.device.type == "cpu":
+                tmix_backend = "pytorch"
+            else:
+                tmix_backend = "cuda"
+
+        if tmix_backend == "pytorch_ref" or tmix_backend == "pytorch_ref_ori":
+            # Pure pytorch mode for rwkv attention
+            # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref
+            # Reference minimal compilation version
+            w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
+            xx, wkv_state_out = rwkv7_attn_pytorch_ref(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
+        elif tmix_backend == "pytorch_ref_fp32":
+            # Pure pytorch mode for rwkv attention
+            # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch_ref_fp32
+            # Modified to follow the same logic as the "cuda" version
+            # w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
+            w = -F.softplus(-(self.w0 + w)) - 0.5
+            xx, wkv_state_out = rwkv7_attn_pytorch_ref_fp32(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
+        elif tmix_backend == "pytorch":
+            # Pure pytorch mode for rwkv attention
+            # from .kernel.rwkv7_attn_pytorch import rwkv7_attn_pytorch
+            # Tweaked pytorch compile variant
+            w = torch.exp(-0.606531 * torch.sigmoid((self.w0 + w).float())) # 0.606531 = exp(-0.5)
+            xx, wkv_state_out = rwkv7_attn_pytorch(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in)
+        elif tmix_backend == "triton":
+            if triton is None:
+                raise ValueError("Triton not available, unable to load triton kernel")
+            # from .kernel.rwkv7_attn_triton import 
rwkv7_attn_triton + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_triton(r, w, k, v, kk, iclr, s0=wkv_state_in) + elif tmix_backend == "triton_bighead": + if triton is None: + raise ValueError("Triton not available, unable to load triton kernel") + # from .kernel.rwkv7_attn_triton import rwkv7_attn_triton_bighead + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_triton_bighead(r, w, k, v, kk, iclr, s0=wkv_state_in) + elif tmix_backend == "cuda_ref": + # Cuda based method for rwkv attention + # from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda_ref + # Reference cuda version (no state output) + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_cuda_ref(r, w, k, v, kk, iclr, s0=wkv_state_in) + elif tmix_backend == "cuda": + # Cuda based method for rwkv attention + # from .kernel.rwkv7_attn_cuda import rwkv7_attn_cuda + # Modified cuda version (with state output) + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_cuda(r, w, k, v, kk, iclr, s0=wkv_state_in) + elif tmix_backend == "fla": + # FLA based method for rwkv attention + # from .kernel.rwkv7_attn_fla import rwkv7_attn_fla + # FLA runs with the softplus w + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) + elif tmix_backend == "fla_fused" or tmix_backend == "fused_fla": + # FLA based method for rwkv attention + # from .kernel.rwkv7_attn_fla import rwkv7_attn_fused_reccurent_fla + # FLA runs with the softplus w + w = -F.softplus(-(self.w0 + w)) - 0.5 + xx, wkv_state_out = rwkv7_attn_fused_reccurent_fla(r, w, k, v, kk, iclr, BATCH_SIZE, SEQ_LEN, N_HEAD, HEAD_SIZE, xx, wkv_state_in) + else: + raise ValueError(f"Unknown tmix_backend: {tmix_backend}") + + # wkv_state_in normalization of type + if wkv_state_in is not None: + wkv_state_out = wkv_state_out.to(wkv_state_in.dtype) + + ######## cuda-based method + # wkv_state_out = wkv_state_in.clone() + # w = -F.softplus(-(self.w0 + w)) - 0.5 # soft-clamp to (-inf, -0.5) + # xx = RWKV7_OP(wkv_state_out, r, w, k, v, -kk, kk*a) + ######## cuda-based method + + xx = self.ln_x(xx.view(BATCH_SIZE * SEQ_LEN, IN_EMB_SIZE)).view(BATCH_SIZE, SEQ_LEN, IN_EMB_SIZE) + xx = xx + ((r.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*k.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)*self.r_k).sum(dim=-1, keepdim=True) * v.view(BATCH_SIZE,SEQ_LEN,N_HEAD,-1)).view(BATCH_SIZE,SEQ_LEN,IN_EMB_SIZE) + xx = self.output(xx * g) + + return xx, shift_state_out, wkv_state_out, v_first_val + + @torch.compile(mode="default") + def forward_with_default_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val_in:Tensor, out_x:Tensor, shift_state_out:Tensor, wkv_state_out:Tensor, v_first_val_out:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: + ''' + Compiled varient of the forward function + With no new tensors being created for the output + Useful for static memory allocation optimizations inference + ''' + out_x[:], shift_state_out[:], wkv_state_out[:], v_first_val_out[:] = self.forward(in_x, shift_state_in, wkv_state_in, v_first_val_in) + return out_x, shift_state_out, wkv_state_out, v_first_val_out + + @torch.compile(mode="reduce-overhead") + def forward_with_reduce_compile(self, in_x:Tensor, shift_state_in:Tensor, wkv_state_in:Tensor, v_first_val:Tensor) -> tuple[Tensor,Tensor,Tensor,Tensor]: + ''' + Compiled varient of the forward function + With no input tensor being modified. 
+ Useful for reduce-overhead compile mode + ''' + return self.forward(in_x, shift_state_in, wkv_state_in, v_first_val) + + # --------------------------------- + # + # Model state handling + # + # --------------------------------- + + def load_from_model_state_dict(self, model_state_dict: dict, layer_id:int, non_blocking:bool=True): + ''' + Given the Full/partial RWKV model weights, loaded via `torch.load` + Setup the the current module weights, using the layer_id + ''' + # Get the current state_dict + current_state_dict = self.state_dict() + + # Iterate each parameter in the state_dict, and compare from the model + for n in current_state_dict: + model_key = f"blocks.{layer_id}.att.{n}" + if model_key not in model_state_dict: + continue + + # Copy the values from the state_dict + try: + current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) + except Exception as e: + print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") + raise e + +# ---------------- +# block/rwkv7_layer_block.py +# ---------------- +import torch +from torch import nn +from torch import Tensor +from typing import Union +from torch.nn import functional as F + +# from .rwkv7_block_config_map import RWKV7BlockConfigMap +# from .rwkv7_channel_mix import RWKV7ChannelMix +# from .rwkv7_time_mix import RWKV7TimeMix + +class RWKV7LayerBlock(torch.nn.Module): + ''' + layer block for RWKV V7 + ''' + + def __init__(self, configMap: Union[RWKV7BlockConfigMap, any]): + super().__init__() + + configMap:RWKV7BlockConfigMap = RWKV7BlockConfigMap.normalize(configMap) + self.configMap = configMap + + # Get required props + hidden_size = configMap.hidden_size + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + dropout_rate = configMap.dropout_rate + + # Get valid layer_id + layer_id = configMap.get_layer_id(-1) + assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' + + # Setup the layernorms, and mixes + self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + + if layer_id == 0: + self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + else: + self.ln0 = nn.Identity(device=device) + + # Setup the time and channel mix + self.att = RWKV7TimeMix(configMap) + self.ffn = RWKV7ChannelMix(configMap) + + # Setup droupout at block level + if dropout_rate > 0.0: + self.drop0 = nn.Dropout(p = dropout_rate,device=device) + self.drop1 = nn.Dropout(p = dropout_rate,device=device) + else: + self.drop0 = nn.Identity(device=device) + self.drop1 = nn.Identity(device=device) + + def init_parameters(self): + ''' + Reset the parameters of the block, to an initial state used for training a model from scratch + ''' + configMap = self.configMap + + # Get required props + hidden_size = configMap.hidden_size + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + dropout_rate = configMap.dropout_rate + + # Get valid layer_id + layer_id = configMap.get_layer_id(-1) + assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' + + # Redo the Setup for the layernorms, and mixes + self.ln1 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + self.ln2 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + + if layer_id == 0: + self.ln0 = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + else: + self.ln0 = nn.Identity(device=device) + + # Call the sub blocks init_parameters + 
self.att.init_parameters() + self.ffn.init_parameters() + + def forward( + self, x:torch.Tensor, + last_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], + v_first:torch.Tensor + ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: + ''' + Forward the block given the input tokens and the last state + Last state is a tuple of the following + - time mix shift state + - time mix wkv state + - channel mix shift state + + Returns a pair of the output embedding, v_first and the next state + ''' + + # # Ensure everything is in the right device + # x = x.to(self.ln1.weight.device) + # last_state = [ s.to(self.ln1.weight.device) for s in last_state ] + + # Note, that this only applies for layer 0 + ln0_out = self.ln0(x) + + # assert self.ln1(x) is not None + # assert last_state.tmix_shift is not None + # assert last_state.tmix_wkv is not None + + att_out, tmix_shift, tmix_wkv, v_first = self.att( + self.ln1(ln0_out), + last_state[0], # tmix_shift, + last_state[1], # tmix_wkv + v_first + ) + + # x = x + att_out + x = self.drop0(ln0_out + att_out) + + ffn_out, ffn_state = self.ffn( + self.ln2(x), + last_state[2] # cmix_shift, + ) + + # x = x + ffn_out + x = self.drop1(x + ffn_out) + + # # Debugging for NaN + # layer_id = self.configMap.get_layer_id(-1) + # assert torch.isnan(att_out).sum() == 0, f'NaN detected att_out @ layer {layer_id}' + # assert torch.isnan(ffn_out).sum() == 0, f'NaN detected ffn_out @ layer {layer_id}' + # assert torch.isnan(v_first).sum() == 0, f'NaN detected v_first @ layer {layer_id}' + # assert torch.isnan(tmix_shift).sum() == 0, f'NaN detected tmix_shift @ layer {layer_id}' + # assert torch.isnan(tmix_wkv).sum() == 0, f'NaN detected tmix_wkv @ layer {layer_id}' + # assert torch.isnan(ffn_state).sum() == 0, f'NaN detected ffn_state @ layer {layer_id}' + # assert torch.isnan(x).sum() == 0, f'NaN detected block out @ layer {layer_id}' + + return x, (tmix_shift, tmix_wkv, ffn_state), v_first + + @torch.compile(mode="default") + def forward_with_default_compile( + self, + in_x:torch.Tensor, + in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], + in_v_first:torch.Tensor, + out_x:torch.Tensor, + out_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], + out_v_first:torch.Tensor + ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: + ''' + Compiled varient of the forward function + With no new tensors being created for the output + Useful for static memory allocation optimizations inference + ''' + out_x[:], tmp_state, out_v_first[:] = self.forward(in_x, in_state, in_v_first) + out_state[0][:], out_state[1][:], out_state[2][:] = tmp_state + return out_x, out_state, out_v_first + + @torch.compile(mode="reduce-overhead") + def forward_with_reduce_compile(self, in_x: torch.Tensor, in_state: tuple[torch.Tensor,torch.Tensor,torch.Tensor], in_v_first:torch.Tensor) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: + ''' + Compiled varient of the forward function + ''' + return self.forward(in_x, in_state, in_v_first) + + def load_from_model_state_dict(self, model_state_dict:dict, layer_id:int=-1, non_blocking:bool=True): + ''' + Given the Full/partial RWKV model weights, load the block weights accordingly + ''' + if layer_id <= -1: + layer_id = self.configMap.get_layer_id(-1) + assert layer_id >= 0, f'layer_id must be >= 0, got {layer_id}' + + # Get the current state_dict + current_state_dict = self.state_dict() + + # Iterate each parameter in the state_dict, and compare from the 
model + for n in current_state_dict: + model_key = f"blocks.{layer_id}.{n}" + if model_key not in model_state_dict: + continue + + # Copy the values from the state_dict + try: + current_state_dict[n].copy_(model_state_dict[model_key], non_blocking=non_blocking) + except Exception as e: + print(f"[ERROR] loading: {model_key} | model shape: {current_state_dict[n].shape} | weight shape: {model_state_dict[model_key].shape}") + raise e + +# ---------------- +# model/rwkv7_goose_config_map.py +# ---------------- +from dataclasses import dataclass +from typing import Optional +from typing import Union +import torch + +# from ..block.rwkv7_block_config_map import RWKV7BlockConfigMap + +@dataclass +class RWKV7GooseConfigMap(RWKV7BlockConfigMap): + # This is the world tokenizer size + vocab_size: int = 65536 + init_state_wkv: bool = False + + # --- + # Initializer, with excess arg ignore + # --- + def __init__( + self, + vocab_size: int = 65536, + init_state_wkv: bool = False, + **kwargs + ) -> None: + self.vocab_size = vocab_size + self.init_state_wkv = init_state_wkv + super().__init__(**kwargs) + + @staticmethod + def normalize(config_map: any) -> 'RWKV7GooseConfigMap': + ''' + Converts either maps, objs or RWKV7BlockConfigMap + ''' + if isinstance(config_map, RWKV7GooseConfigMap): + return config_map + + if isinstance(config_map, dict): + return RWKV7GooseConfigMap(**config_map) + + if hasattr(config_map, '__dict__'): + return RWKV7GooseConfigMap(**config_map.__dict__) + + raise ValueError(f"Unsupported config_map type: {type(config_map)}") + + @staticmethod + def from_model_state_dict(state_dict: dict, **kwargs) -> 'RWKV7GooseConfigMap': + ''' + Converts the state dict to the config map + ''' + + # Iterate and count the layers + num_hidden_layers = 0 + for key in state_dict.keys(): + if key.startswith('blocks.'): + idx = key.split('.')[1] + num_hidden_layers = max(num_hidden_layers, int(idx)+1) + + # Enable wkv_state + if 'init_state.0.wkv' in state_dict: + kwargs['init_state_wkv'] = True + + # Initialize the config map, with the configured values + return RWKV7GooseConfigMap( + num_hidden_layers=num_hidden_layers, + hidden_size=state_dict['emb.weight'].shape[1], + vocab_size=state_dict['emb.weight'].shape[0], + # init_state_wkv=hasattr(state_dict, 'init_state.0.wkv'), + + # n_head=state_dict['blocks.0.att.r_k'].shape[0], + head_size=state_dict['blocks.0.att.r_k'].shape[1], + + hidden_size_att=state_dict['blocks.0.att.key.weight'].shape[1], + hidden_size_ffn=state_dict['blocks.0.ffn.key.weight'].shape[0], + + **kwargs + ) + + +# ---------------- +# model/rwkv7_goose_model.py +# ---------------- +import torch +from torch import nn +from torch import Tensor +from typing import Union + +# from .rwkv7_goose_config_map import RWKV7GooseConfigMap +# from ..block.rwkv7_layer_block import RWKV7LayerBlock + +class RWKV7GooseModel(nn.Module): + ''' + RWKV7 Goose Model architecture + Simplified implementation + ''' + + def __init__(self, config: Union[RWKV7GooseConfigMap, any, None] = None): + # Initialize the base class + super().__init__() + + # Normalize the config + configMap:RWKV7GooseConfigMap = RWKV7GooseConfigMap.normalize(config) + self.configMap = configMap + + # Get the required prop + num_hidden_layers = configMap.num_hidden_layers + vocab_size = configMap.vocab_size + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + hidden_size = configMap.hidden_size + + # Embedding layer + self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) + + # main 
block layers + blockList = [None]*num_hidden_layers + for i in range(num_hidden_layers): + blockList[i] = RWKV7LayerBlock(configMap.new_block_config_map(layer_id=i)) + self.blocks = nn.ModuleList(blockList) + + # ln_out and head + self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) + + # init state tuning support + if configMap.init_state_wkv: + stateTuneList = [None]*num_hidden_layers + for i in range(num_hidden_layers): + stateTuneList[i] = nn.ParameterDict({ + "wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=dtype)), + }) + self.init_state = nn.ParameterList(stateTuneList) + + def init_parameters(self): + ''' + Reset the parameters of the block, to an initial state used for training a model from scratch + ''' + + # Get the required prop + configMap = self.configMap + num_hidden_layers = configMap.num_hidden_layers + vocab_size = configMap.vocab_size + device = configMap.get_device(None) + dtype = configMap.get_dtype('bfloat16') + hidden_size = configMap.hidden_size + + # Iterate and reset the blocks + for i in range(num_hidden_layers): + self.blocks[i].init_parameters() + + # Reinit the Embedding layer + self.emb = nn.Embedding(vocab_size, hidden_size, device=device, dtype=dtype) + + # Reinit the ln_out and head + self.ln_out = nn.LayerNorm(hidden_size, device=device, dtype=dtype) + if self.head is not None: + self.head = nn.Linear(hidden_size, vocab_size, bias=False, device=device, dtype=dtype) + + # Reinit the init state tuning support + if configMap.init_state_wkv: + stateTuneList = [None]*num_hidden_layers + for i in range(num_hidden_layers): + stateTuneList[i] = nn.ParameterDict({ + "wkv": nn.Parameter(torch.zeros(hidden_size // 64, 64, 64, device=device, dtype=torch.float)), + }) + self.init_state = nn.ParameterList(stateTuneList) + + def load_from_model_state_dict(self, state_dict: dict, non_blocking:bool=True): + ''' + Given the Full/partial RWKV model weights, loaded via `torch.load` + Setup the RWKV_TimeMix model weights, using the layer_id + ''' + for i, block in enumerate(self.blocks): + block.load_from_model_state_dict(state_dict, i, non_blocking=non_blocking) + + self.ln_out.weight.data.copy_(state_dict['ln_out.weight'], non_blocking=True) + self.ln_out.bias.data.copy_(state_dict['ln_out.bias'], non_blocking=True) + self.head.weight.data.copy_(state_dict['head.weight'], non_blocking=True) + self.emb.weight.data.copy_(state_dict['emb.weight'], non_blocking=True) + + ### --- + ### + ### Init state handling + ### + ### --- + + def get_init_state(self, batch_size:int=1, skip_init_state:bool=False) -> list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]: + ''' + Get an initialized copy of the model state, for the given batch size + ''' + # Get required configs + hidden_size = self.configMap.hidden_size + init_state_wkv = self.configMap.init_state_wkv + num_hidden_layers = self.configMap.num_hidden_layers + + # Prepare the initial state + init_state = [ None for i in range(num_hidden_layers) ] + for i in range(num_hidden_layers): + device = self.blocks[i].ln1.weight.data.device + dtype = self.blocks[i].ln1.weight.data.dtype + + # Use the saved init_state if enabled + # TODO: Consider letting the wkv_state dtype be a parameter + wkv_state = torch.zeros(batch_size, hidden_size // 64, 64, 64, device=device, dtype=torch.float) + if init_state_wkv and skip_init_state == False: + init_wkv = self.init_state[i]["wkv"] + for b in range(batch_size): + 
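+                    # Copy the learned per-layer initial wkv state into each batch entry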
wkv_state[b][:] = init_wkv + + # Prepare the state + init_state[i] = ( + torch.zeros(batch_size, hidden_size, device=device, dtype=dtype), + wkv_state, + torch.zeros(batch_size, hidden_size, device=device, dtype=dtype) + ) + return init_state + + ### --- + ### + ### Model Forward + ### + ### --- + + def forward( + self, idx:torch.Tensor, + prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, + ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]] = None, + ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: + ''' + Forward the block set, given the input tokens and the last state + Last state is a list of tuple of the following + - time mix shift state + - time mix wkv state + - channel mix shift state + + Returns a pair of the output embedding and the next state + ''' + # Prepare the state, with the batch size + if prv_stateList is None: + prv_stateList = self.get_init_state(idx.shape[0]) + + # If no return state is set, let _forward_internal, set it up + if ret_stateList is None: + ret_stateList = [ None for i in range(self.configMap.num_hidden_layers) ] + return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=False) + + # Forward internally + return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True) + + def _forward_block_hook(self, + block:RWKV7LayerBlock, + x_hidden_state:torch.Tensor, + prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], + v_first:torch.Tensor + ) -> tuple[torch.Tensor,tuple[torch.Tensor,torch.Tensor,torch.Tensor],torch.Tensor]: + ''' + Forward block hook operation, that is easily overridable. + To implement gradient checkpointing for use in various trainers + ''' + x_hidden_state = x_hidden_state.to(block.ln1.weight.device, non_blocking=True) + return block(x_hidden_state, prv_stateList, v_first) + + def _forward_internal( + self, idx:torch.Tensor, + prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], + ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]], + overwrite_ret_tensor:bool=False + ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]: + ''' + Internal forward operations, which assumes the state is already initialized + Due to the lack of safety checks, this should not be used directly + ''' + + # Lets get the embedding + idx = idx.to(self.emb.weight.device, non_blocking=True) + x_hidden_state = self.emb(idx) + v_first = None + + # Iterate the block layers, compute the x_hidden_state + if overwrite_ret_tensor: + for i, block in enumerate(self.blocks): + # x_hidden_state, last_block_state, v_first = block(x_hidden_state, prv_stateList[i], v_first) + x_hidden_state, last_block_state, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) + ret_stateList[i][0][:] = last_block_state[0] + ret_stateList[i][1][:] = last_block_state[1] + ret_stateList[i][2][:] = last_block_state[2] + else: + ret_stateList = [ None for i in range( len(self.blocks) ) ] + for i, block in enumerate(self.blocks): + # x_hidden_state, ret_sublist, v_first = block(x_hidden_state, prv_stateList[i], v_first) + x_hidden_state, ret_sublist, v_first = self._forward_block_hook(block, x_hidden_state, prv_stateList[i], v_first) + ret_stateList[i] = ret_sublist + + # Final layer norm, and head + x_hidden_state = x_hidden_state.to(self.ln_out.weight.device, non_blocking=True) + x_hidden_state = self.ln_out(x_hidden_state) + x_out = self.head(x_hidden_state) + + # Return the output 
and the state list
+        return x_out, ret_stateList
+
+    @torch.compile(mode="default")
+    def forward_with_default_compile(
+        self, idx:torch.Tensor,
+        prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
+        ret_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]],
+    ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
+        '''
+        Compiled variant of the forward function
+        With no new tensors being created for the output
+        Useful for static memory allocation optimizations during inference
+        '''
+        # Forward internally
+        return self._forward_internal(idx, prv_stateList, ret_stateList, overwrite_ret_tensor=True)
+
+    @torch.compile(mode="reduce-overhead")
+    def forward_with_reduce_compile(
+        self, in_idx:torch.Tensor,
+        prv_stateList:list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]
+    ) -> tuple[torch.Tensor,list[tuple[torch.Tensor,torch.Tensor,torch.Tensor]]]:
+        '''
+        Compiled variant of the forward function; requires the previous state to be passed in
+        '''
+        return self._forward_internal(in_idx, prv_stateList, None, overwrite_ret_tensor=False)
+
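+# ----------------
+# Usage sketch (illustrative only)
+# ----------------
+# A minimal example of how the classes above are typically wired together.
+# It assumes RWKV7GooseConfigMap accepts the keyword arguments shown (they mirror the
+# fields read in `from_model_state_dict` above); adjust values to your actual setup.
+#
+#   import torch
+#
+#   model = RWKV7GooseModel({
+#       "num_hidden_layers": 2,
+#       "hidden_size": 256,
+#       "head_size": 64,
+#       "vocab_size": 65536,
+#   })
+#   model.init_parameters()
+#
+#   tokens = torch.randint(0, 65536, (1, 16))      # [batch_size, seq_len]
+#   state  = model.get_init_state(batch_size=1)    # per-layer (tmix_shift, tmix_wkv, cmix_shift)
+#   logits, state = model.forward(tokens, state)   # logits: [1, 16, vocab_size]
+#
+#   # The returned state can be fed back in to continue the same sequences:
+#   logits, state = model.forward(torch.randint(0, 65536, (1, 16)), state)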