diff --git "a/rwkv/model.py" "b/rwkv/model.py"
deleted file mode 100644--- "a/rwkv/model.py"
+++ /dev/null
@@ -1,3049 +0,0 @@
-########################################################################################################
-# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
-########################################################################################################
-
-from typing import Optional
-import types, gc, os, time, re, math
-import torch
-import torch.nn as nn
-from torch.nn import functional as F
-
-torch.backends.cudnn.benchmark = True
-torch.backends.cudnn.allow_tf32 = True
-torch.backends.cuda.matmul.allow_tf32 = True
-current_path = os.path.dirname(os.path.abspath(__file__))
-
-########################################################################################################
-
-if os.environ.get("RWKV_JIT_ON") != "0":
-    os.environ["RWKV_JIT_ON"] = "1"
-    MyModule = torch.jit.ScriptModule
-    MyFunction = torch.jit.script_method
-    MyStatic = torch.jit.script
-else:
-    MyModule = torch.nn.Module
-
-    def __nop(ob):
-        return ob
-
-    MyFunction = __nop
-    MyStatic = __nop
-
-if os.environ.get("RWKV_CUDA_ON") == "1":
-    from torch.utils.cpp_extension import load
-
-    try:
-        load(
-            name=f"wkv_cuda",
-            sources=[
-                f"{current_path}/cuda/wrapper.cpp",
-                f"{current_path}/cuda/operators.cu",
-                f"{current_path}/cuda/gemm_fp16_cublas.cpp",
-            ],
-            verbose=True,
-            extra_ldflags=["cublas.lib" if os.name == "nt" else ""],
-            extra_cuda_cflags=[
-                "--use_fast_math",
-                "-O3",
-                "--extra-device-vectorization",
-            ],
-            is_python_module=False,
-        )
-        DISABLE_CUBLAS_GEMM = False
-    except:
-        print(
-            "Failed to build cuBLAS matmul, falling back to torch.matmul. Small model with fp16 will overflow."
-        )
-        load(
-            name=f"wkv_cuda",
-            sources=[
-                f"{current_path}/cuda/wrapper.cpp",
-                f"{current_path}/cuda/operators.cu",
-            ],
-            verbose=True,
-            extra_cuda_cflags=[
-                "--use_fast_math",
-                "-O3",
-                "--extra-device-vectorization",
-            ],
-            extra_cflags=["-DDISABLE_CUBLAS_GEMM"],
-            is_python_module=False,
-        )
-        DISABLE_CUBLAS_GEMM = True
-
-    @MyStatic
-    def cuda_wkv(T: int, C: int, w, u, k, v, aa, bb, pp):
-        assert 1 * C % min(C, 32) == 0
-        assert (
-            k.dtype == v.dtype == torch.float16 or k.dtype == v.dtype == torch.float32
-        )
-        assert w.dtype == u.dtype == aa.dtype == bb.dtype == pp.dtype == torch.float32
-        w = w.contiguous()
-        u = u.contiguous()
-        k = k.contiguous()
-        v = v.contiguous()
-        y = torch.empty(
-            (T, C),
-            device=w.device,
-            memory_format=torch.contiguous_format,
-            dtype=k.dtype,
-        )
-        torch.ops.rwkv.wkv_forward(1, T, C, w, u, k, v, y, aa, bb, pp)
-        return y, aa, bb, pp
-
-    @MyStatic
-    def cuda_mm8_seq(B: int, N: int, M: int, x, w, mx, rx, my, ry):
-        assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
-        assert x.dtype == torch.float32 or x.dtype == torch.float16
-        assert w.dtype == torch.uint8
-        assert x.shape == (B, N)
-        assert w.shape == (N, M)
-        assert rx.shape == mx.shape == (M,)
-        assert ry.shape == my.shape == (N, 1)
-        y = torch.empty((B, M), device=w.device, dtype=x.dtype)
-        torch.ops.rwkv.mm8_seq(B, N, M, x, w, mx, rx, my, ry, y)
-        return y
-
-    @MyStatic
-    def cuda_mm8_one(N: int, M: int, x, w, mx, rx, my, ry):
-        assert x.dtype == mx.dtype == rx.dtype == my.dtype == ry.dtype
-        assert x.dtype == torch.float32 or x.dtype == torch.float16
-        assert w.dtype == torch.uint8
-        assert x.shape == (N,)
-        assert w.shape == (N, M)
-        assert rx.shape == mx.shape == (M,)
-        assert ry.shape == my.shape == (N, 1)
-        y = torch.zeros((M,), device=w.device, dtype=torch.float32)
-        torch.ops.rwkv.mm8_one(N, M, x, w, mx, rx, my, ry, y)
-        return y.to(dtype=x.dtype)
-
-else:
-    os.environ["RWKV_CUDA_ON"] = "0"
-
-
-@MyStatic
-def torch_mm8_seq(x, w, mx, rx, my, ry):
-    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
-
-
-@MyStatic
-def torch_mm8_one(x, w, mx, rx, my, ry):
-    return x @ ((w.to(dtype=x.dtype) + 0.5) * ry * rx + my + mx)
-
-
-if os.environ.get("RWKV_CUDA_ON") == "1":
-
-    @MyStatic
-    def mm8_seq(x, w, mx, rx, my, ry):
-        if w.device.type == "cuda" and x.dtype == torch.float16:
-            B, N, M = x.shape[0], w.shape[0], w.shape[1]
-            return cuda_mm8_seq(B, N, M, x, w, mx, rx, my, ry)
-        else:
-            return torch_mm8_seq(x, w, mx, rx, my, ry)
-
-    @MyStatic
-    def mm8_one(x, w, mx, rx, my, ry):
-        if w.device.type == "cuda":
-            N, M = w.shape[0], w.shape[1]
-            return cuda_mm8_one(N, M, x, w, mx, rx, my, ry)
-        else:
-            return torch_mm8_one(x, w, mx, rx, my, ry)
-
-else:
-
-    @MyStatic
-    def mm8_seq(x, w, mx, rx, my, ry):
-        return torch_mm8_seq(x, w, mx, rx, my, ry)
-
-    @MyStatic
-    def mm8_one(x, w, mx, rx, my, ry):
-        return torch_mm8_one(x, w, mx, rx, my, ry)
-
-
-def mm8(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    mx: torch.Tensor,
-    rx: torch.Tensor,
-    my: torch.Tensor,
-    ry: torch.Tensor,
-):
-    if len(x.shape) == 1:
-        return mm8_one(x, w, mx, rx, my, ry)
-    return mm8_seq(x, w, mx, rx, my, ry)
-
-
-def matmul(
-    a,
-    b,
-    mx: Optional[torch.Tensor] = None,
-    rx: Optional[torch.Tensor] = None,
-    my: Optional[torch.Tensor] = None,
-    ry: Optional[torch.Tensor] = None,
-    output_dtype: Optional[torch.dtype] = None,
-) -> torch.Tensor:
-    if output_dtype is None:
-        output_dtype = a.dtype
-    if b.dtype in [torch.float16, torch.bfloat16, torch.float32]:
-        assert a.dtype == b.dtype
-        return matmul_float(a, b, output_dtype=output_dtype)
-    elif b.dtype == torch.uint8:
-        assert mx is not None
-        assert rx is not None
-        assert my is not None
-        assert ry is not None
-        return mm8(a, b, mx, rx, my, ry).to(output_dtype)
-    else:
-        raise ValueError("Unsupported dtype")
-
-
-if os.environ.get("RWKV_CUDA_ON") == "1" and not DISABLE_CUBLAS_GEMM:
-
-    def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
-        if output_dtype is None:
-            output_dtype = a.dtype
-        if a.dtype == b.dtype == torch.float16 and a.device.type == "cuda":
-            if len(a.shape) == 1:
-                assert len(b.shape) == 2
-                c = torch.empty((b.shape[-1],), dtype=output_dtype, device=a.device)
-                a = a.unsqueeze(0)
-            else:
-                assert len(a.shape) == len(b.shape)
-                assert len(a.shape) == 2 or len(a.shape) == 3
-                # torch.empty((*a.shape[:-1], b.shape[-1])) doesn't work with jit
-                if len(a.shape) == 2:
-                    c = torch.empty(
-                        (a.shape[0], b.shape[-1]), dtype=output_dtype, device=a.device
-                    )
-                else:
-                    c = torch.empty(
-                        (a.shape[0], a.shape[1], b.shape[-1]),
-                        dtype=output_dtype,
-                        device=a.device,
-                    )
-            torch.ops.rwkv.gemm_fp16_cublas(a, b, c)
-            return c
-        else:
-            return (a @ b).to(output_dtype)
-
-else:
-
-    def matmul_float(a, b, output_dtype: Optional[torch.dtype] = None):
-        return (a @ b).to(output_dtype)
-
-
-if os.environ.get("RWKV_DML_ON") == "1":
-    import torch_directml
-
-    print("PyTorch with DirectML Enabled")
-
-if os.environ.get("RWKV_V7_ON") == "1":
-
-    print(f'\n### RWKV-7 "Goose" enabled ###\n')
-
-    torch.backends.cudnn.benchmark = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.backends.cuda.matmul.allow_tf32 = True
-    # torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
-    # torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True
-    torch._C._jit_set_autocast_mode(False)
-
-    MyModule = torch.jit.ScriptModule
-    MyFunction = torch.jit.script_method
-    MyStatic = torch.jit.script
-    from typing import List
-
-    DTYPE = None
-    DEVICE = None
-    HEAD_SIZE = 64
-
-    if os.environ.get("RWKV_CUDA_ON") == "1":
-        from torch.utils.cpp_extension import load
-
-        load(
-            name="wkv7s",
-            sources=[
-                f"{current_path}/cuda/rwkv7_op.cpp",
-                f"{current_path}/cuda/rwkv7.cu",
-            ],
-            is_python_module=False,
-            verbose=True,
-            extra_cuda_cflags=[
-                "-res-usage",
-                "--use_fast_math",
-                "-O3",
-                "-Xptxas -O3",
-                "--extra-device-vectorization",
-                f"-D_N_={HEAD_SIZE}",
-            ],
-        )
-
-        class WKV_7(torch.autograd.Function):
-            @staticmethod
-            def forward(ctx, state, r, w, k, v, a, b):
-                with torch.no_grad():
-                    T, C = r.size()
-                    H = C // HEAD_SIZE
-                    N = HEAD_SIZE
-                    assert HEAD_SIZE == C // H
-                    assert all(x.dtype == DTYPE for x in [r, w, k, v, a, b])
-                    assert all(x.is_contiguous() for x in [r, w, k, v, a, b])
-                    y = torch.empty(
-                        (T, C),
-                        device=DEVICE,
-                        dtype=r.dtype,
-                        requires_grad=False,
-                        memory_format=torch.contiguous_format,
-                    )
-
-                    if DTYPE == torch.float16:
-                        torch.ops.wkv7s.forward_fp16(
-                            1, T, C, H, state, r, w, k, v, a, b, y
-                        )
-                    elif DTYPE == torch.bfloat16:
-                        torch.ops.wkv7s.forward_bf16(
-                            1, T, C, H, state, r, w, k, v, a, b, y
-                        )
-                    elif DTYPE == torch.float32:
-                        torch.ops.wkv7s.forward_fp32(
-                            1, T, C, H, state, r, w, k, v, a, b, y
-                        )
-
-                    return y
-
-        def RWKV7_OP(state, r, w, k, v, a, b):
-            return WKV_7.apply(state, r, w, k, v, a, b)
-
-    ########################################################################################################
-
-    class RWKV_x070(MyModule):
-        def __init__(self, model, strategy):
-            global DTYPE, DEVICE
-            super().__init__()
-            self.eval()
-            args = types.SimpleNamespace()
-            self.args = args
-            args.MODEL_NAME = model
-
-            print(f"Loading {model} ({strategy})\n")
-
-            ss = strategy.split(" ")
-            DEVICE = ss[0]
-            if ss[1] == "fp16":
-                DTYPE = torch.half
-            elif ss[1] == "fp32":
-                DTYPE = torch.float32
-            elif ss[1] == "bf16":
-                DTYPE = torch.bfloat16
-            else:
-                assert (
-                    False
-                ), "currently rwkv7 strategy must be: cuda/cpu fp16/fp32/bf16"
-
-            self.z = torch.load(args.MODEL_NAME + ".pth", map_location=DEVICE)
-            z = self.z
-
-            self.n_head, self.head_size = z["blocks.0.att.r_k"].shape
-            args.head_size = self.head_size
-            args.vocab_size, args.n_embd = z["emb.weight"].shape
-
-            args.n_layer = 0
-            keys = list(z.keys())
-            for k in keys:
-                layer_id = int(k.split(".")[1]) if ("blocks." in k) else 0
-                args.n_layer = max(args.n_layer, layer_id + 1)
-                if (
-                    "key.weight" in k
-                    or "value.weight" in k
-                    or "receptance.weight" in k
-                    or "output.weight" in k
-                    or "head.weight" in k
-                ):
-                    z[k] = z[k].t()
-                z[k] = z[k].squeeze().to(dtype=DTYPE)
-                if k.endswith("att.r_k"):
-                    z[k] = z[k].flatten()
-            self.n_embd = args.n_embd
-            self.n_layer = args.n_layer
-
-            z["emb.weight"] = F.layer_norm(
-                z["emb.weight"],
-                (args.n_embd,),
-                weight=z["blocks.0.ln0.weight"],
-                bias=z["blocks.0.ln0.bias"],
-            )
-            torch.cuda.empty_cache()
-            z["blocks.0.att.v0"] = z["blocks.0.att.a0"]  # actually ignored
-            z["blocks.0.att.v1"] = z["blocks.0.att.a1"]  # actually ignored
-            z["blocks.0.att.v2"] = z["blocks.0.att.a2"]  # actually ignored
-
-        def forward(self, idx, state, full_output=False):
-            if state == None:
-                state = [None for _ in range(self.args.n_layer * 3)]
-                for i in range(
-                    self.args.n_layer
-                ):  # state: 0=att_x_prev 1=att_kv 2=ffn_x_prev
-                    state[i * 3 + 0] = torch.zeros(
-                        self.args.n_embd,
-                        dtype=DTYPE,
-                        requires_grad=False,
-                        device=DEVICE,
-                    )
-                    state[i * 3 + 1] = torch.zeros(
-                        (
-                            self.args.n_embd // self.args.head_size,
-                            self.args.head_size,
-                            self.args.head_size,
-                        ),
-                        dtype=torch.float,
-                        requires_grad=False,
-                        device=DEVICE,
-                    )
-                    state[i * 3 + 2] = torch.zeros(
-                        self.args.n_embd,
-                        dtype=DTYPE,
-                        requires_grad=False,
-                        device=DEVICE,
-                    )
-
-            if type(idx) is list:
-                if len(idx) > 1:
-                    return self.forward_seq(idx, state, full_output)
-                else:
-                    return self.forward_one(idx[0], state)
-            else:
-                return self.forward_one(idx, state)
-
-        @MyFunction
-        def forward_one(self, idx: int, state: List[torch.Tensor]):
-            with torch.no_grad():
-                z = self.z
-                x = z["emb.weight"][idx]
-
-                v_first = torch.empty_like(x)
-                for i in range(self.n_layer):
-                    bbb = f"blocks.{i}."
-                    att = f"blocks.{i}.att."
-                    ffn = f"blocks.{i}.ffn."
-
-                    xx = F.layer_norm(
-                        x,
-                        (self.n_embd,),
-                        weight=z[bbb + "ln1.weight"],
-                        bias=z[bbb + "ln1.bias"],
-                    )
-
-                    xx, state[i * 3 + 0], state[i * 3 + 1], v_first = (
-                        RWKV_x070_TMix_one(
-                            i,
-                            self.n_head,
-                            self.head_size,
-                            xx,
-                            state[i * 3 + 0],
-                            v_first,
-                            state[i * 3 + 1],
-                            z[att + "x_r"],
-                            z[att + "x_w"],
-                            z[att + "x_k"],
-                            z[att + "x_v"],
-                            z[att + "x_a"],
-                            z[att + "x_g"],
-                            z[att + "w0"],
-                            z[att + "w1"],
-                            z[att + "w2"],
-                            z[att + "a0"],
-                            z[att + "a1"],
-                            z[att + "a2"],
-                            z[att + "v0"],
-                            z[att + "v1"],
-                            z[att + "v2"],
-                            z[att + "g1"],
-                            z[att + "g2"],
-                            z[att + "k_k"],
-                            z[att + "k_a"],
-                            z[att + "r_k"],
-                            z[att + "receptance.weight"],
-                            z[att + "key.weight"],
-                            z[att + "value.weight"],
-                            z[att + "output.weight"],
-                            z[att + "ln_x.weight"],
-                            z[att + "ln_x.bias"],
-                        )
-                    )
-                    x = x + xx
-
-                    xx = F.layer_norm(
-                        x,
-                        (self.n_embd,),
-                        weight=z[bbb + "ln2.weight"],
-                        bias=z[bbb + "ln2.bias"],
-                    )
-
-                    xx, state[i * 3 + 2] = RWKV_x070_CMix_one(
-                        xx,
-                        state[i * 3 + 2],
-                        z[ffn + "x_k"],
-                        z[ffn + "key.weight"],
-                        z[ffn + "value.weight"],
-                    )
-                    x = x + xx
-
-                    # if math.isnan(torch.min(x).item()): print(idx, i)
-
-                x = F.layer_norm(
-                    x, (self.n_embd,), weight=z["ln_out.weight"], bias=z["ln_out.bias"]
-                )
-                x = x @ z["head.weight"]
-                return x, state
-
-        @MyFunction
-        def forward_seq(
-            self, idx: List[int], state: List[torch.Tensor], full_output: bool = False
-        ):
-            with torch.no_grad():
-                z = self.z
-                x = z["emb.weight"][idx]
-
-                v_first = torch.empty_like(x)
-                for i in range(self.n_layer):
-                    bbb = f"blocks.{i}."
-                    att = f"blocks.{i}.att."
-                    ffn = f"blocks.{i}.ffn."
-
-                    xx = F.layer_norm(
-                        x,
-                        (self.n_embd,),
-                        weight=z[bbb + "ln1.weight"],
-                        bias=z[bbb + "ln1.bias"],
-                    )
-
-                    xx, state[i * 3 + 0], state[i * 3 + 1], v_first = (
-                        RWKV_x070_TMix_seq(
-                            i,
-                            self.n_head,
-                            self.head_size,
-                            xx,
-                            state[i * 3 + 0],
-                            v_first,
-                            state[i * 3 + 1],
-                            z[att + "x_r"],
-                            z[att + "x_w"],
-                            z[att + "x_k"],
-                            z[att + "x_v"],
-                            z[att + "x_a"],
-                            z[att + "x_g"],
-                            z[att + "w0"],
-                            z[att + "w1"],
-                            z[att + "w2"],
-                            z[att + "a0"],
-                            z[att + "a1"],
-                            z[att + "a2"],
-                            z[att + "v0"],
-                            z[att + "v1"],
-                            z[att + "v2"],
-                            z[att + "g1"],
-                            z[att + "g2"],
-                            z[att + "k_k"],
-                            z[att + "k_a"],
-                            z[att + "r_k"],
-                            z[att + "receptance.weight"],
-                            z[att + "key.weight"],
-                            z[att + "value.weight"],
-                            z[att + "output.weight"],
-                            z[att + "ln_x.weight"],
-                            z[att + "ln_x.bias"],
-                        )
-                    )
-                    x = x + xx
-
-                    xx = F.layer_norm(
-                        x,
-                        (self.n_embd,),
-                        weight=z[bbb + "ln2.weight"],
-                        bias=z[bbb + "ln2.bias"],
-                    )
-
-                    xx, state[i * 3 + 2] = RWKV_x070_CMix_seq(
-                        xx,
-                        state[i * 3 + 2],
-                        z[ffn + "x_k"],
-                        z[ffn + "key.weight"],
-                        z[ffn + "value.weight"],
-                    )
-                    x = x + xx
-
-                if not full_output:
-                    x = x[-1, :]
-                x = F.layer_norm(
-                    x, (self.n_embd,), weight=z["ln_out.weight"], bias=z["ln_out.bias"]
-                )
-                x = x @ z["head.weight"]
-                return x, state
-
-    ########################################################################################################
-
-    @MyStatic
-    def RWKV_x070_TMix_one(
-        layer_id: int,
-        H: int,
-        N: int,
-        x,
-        x_prev,
-        v_first,
-        state,
-        x_r,
-        x_w,
-        x_k,
-        x_v,
-        x_a,
-        x_g,
-        w0,
-        w1,
-        w2,
-        a0,
-        a1,
-        a2,
-        v0,
-        v1,
-        v2,
-        g1,
-        g2,
-        k_k,
-        k_a,
-        r_k,
-        R_,
-        K_,
-        V_,
-        O_,
-        ln_w,
-        ln_b,
-    ):
-        xx = x_prev - x
-        xr, xw, xk, xv, xa, xg = (
-            x + xx * x_r,
-            x + xx * x_w,
-            x + xx * x_k,
-            x + xx * x_v,
-            x + xx * x_a,
-            x + xx * x_g,
-        )
-
-        r = xr @ R_
-        w = torch.tanh(xw @ w1) @ w2
-        k = xk @ K_
-        v = xv @ V_
-        a = torch.sigmoid(a0 + (xa @ a1) @ a2)
-        g = torch.sigmoid(xg @ g1) @ g2
-
-        kk = torch.nn.functional.normalize((k * k_k).view(H, N), dim=-1, p=2.0).view(
-            H * N
-        )
-        k = k * (1 + (a - 1) * k_a)
-        if layer_id == 0:
-            v_first = v
-        else:
-            v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
-        w = torch.exp(
-            -0.606531 * torch.sigmoid((w0 + w).float())
-        )  # 0.606531 = exp(-0.5)
-
-        vk = v.view(H, N, 1) @ k.view(H, 1, N)
-        ab = (-kk).view(H, N, 1) @ (kk * a).view(H, 1, N)
-        state = state * w.view(H, 1, N) + state @ ab.float() + vk.float()
-        xx = state.to(dtype=x.dtype) @ r.view(H, N, 1)
-
-        xx = torch.nn.functional.group_norm(
-            xx.view(1, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5
-        ).view(H * N)
-        xx = xx + (
-            (r * k * r_k).view(H, N).sum(dim=-1, keepdim=True) * v.view(H, N)
-        ).view(H * N)
-        return (xx * g) @ O_, x, state, v_first
-
-    if os.environ.get("RWKV_CUDA_ON") == "1":
-
-        @MyStatic
-        def RWKV_x070_TMix_seq(
-            layer_id: int,
-            H: int,
-            N: int,
-            x,
-            x_prev,
-            v_first,
-            state,
-            x_r,
-            x_w,
-            x_k,
-            x_v,
-            x_a,
-            x_g,
-            w0,
-            w1,
-            w2,
-            a0,
-            a1,
-            a2,
-            v0,
-            v1,
-            v2,
-            g1,
-            g2,
-            k_k,
-            k_a,
-            r_k,
-            R_,
-            K_,
-            V_,
-            O_,
-            ln_w,
-            ln_b,
-        ):
-            T = x.shape[0]
-            xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x
-            xr, xw, xk, xv, xa, xg = (
-                x + xx * x_r,
-                x + xx * x_w,
-                x + xx * x_k,
-                x + xx * x_v,
-                x + xx * x_a,
-                x + xx * x_g,
-            )
-
-            r = xr @ R_
-            w = torch.tanh(xw @ w1) @ w2
-            k = xk @ K_
-            v = xv @ V_
-            a = torch.sigmoid(a0 + (xa @ a1) @ a2)
-            g = torch.sigmoid(xg @ g1) @ g2
-
-            kk = torch.nn.functional.normalize(
-                (k * k_k).view(T, H, N), dim=-1, p=2.0
-            ).view(T, H * N)
-            k = k * (1 + (a - 1) * k_a)
-            if layer_id == 0:
-                v_first = v
-            else:
-                v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
-
-            w = -torch.nn.functional.softplus(-(w0 + w)) - 0.5
-            xx = RWKV7_OP(state, r, w, k, v, -kk, kk * a)
-
-            xx = torch.nn.functional.group_norm(
-                xx.view(T, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5
-            ).view(T, H * N)
-            xx = xx + (
-                (r * k * r_k).view(T, H, N).sum(dim=-1, keepdim=True) * v.view(T, H, N)
-            ).view(T, H * N)
-            return (xx * g) @ O_, x[-1, :], state, v_first
-
-    else:
-
-        @MyStatic
-        def RWKV_x070_TMix_seq(
-            layer_id: int,
-            H: int,
-            N: int,
-            x,
-            x_prev,
-            v_first,
-            state,
-            x_r,
-            x_w,
-            x_k,
-            x_v,
-            x_a,
-            x_g,
-            w0,
-            w1,
-            w2,
-            a0,
-            a1,
-            a2,
-            v0,
-            v1,
-            v2,
-            g1,
-            g2,
-            k_k,
-            k_a,
-            r_k,
-            R_,
-            K_,
-            V_,
-            O_,
-            ln_w,
-            ln_b,
-        ):
-            T = x.shape[0]
-            xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x
-            xr, xw, xk, xv, xa, xg = (
-                x + xx * x_r,
-                x + xx * x_w,
-                x + xx * x_k,
-                x + xx * x_v,
-                x + xx * x_a,
-                x + xx * x_g,
-            )
-
-            r = xr @ R_
-            w = torch.tanh(xw @ w1) @ w2
-            k = xk @ K_
-            v = xv @ V_
-            a = torch.sigmoid(a0 + (xa @ a1) @ a2)
-            g = torch.sigmoid(xg @ g1) @ g2
-
-            kk = torch.nn.functional.normalize(
-                (k * k_k).view(T, H, N), dim=-1, p=2.0
-            ).view(T, H * N)
-            k = k * (1 + (a - 1) * k_a)
-            if layer_id == 0:
-                v_first = v
-            else:
-                v = v + (v_first - v) * torch.sigmoid(v0 + (xv @ v1) @ v2)
-
-            w = torch.exp(
-                -0.606531 * torch.sigmoid((w0 + w).float())
-            )  # 0.606531 = exp(-0.5)
-            for t in range(T):
-                r_, w_, k_, v_, kk_, a_ = r[t], w[t], k[t], v[t], kk[t], a[t]
-                vk = v_.view(H, N, 1) @ k_.view(H, 1, N)
-                ab = (-kk_).view(H, N, 1) @ (kk_ * a_).view(H, 1, N)
-                state = state * w_.view(H, 1, N) + state @ ab.float() + vk.float()
-                xx[t] = (state.to(dtype=x.dtype) @ r_.view(H, N, 1)).view(H * N)
-
-            xx = torch.nn.functional.group_norm(
-                xx.view(T, H * N), num_groups=H, weight=ln_w, bias=ln_b, eps=64e-5
-            ).view(T, H * N)
-            xx = xx + (
-                (r * k * r_k).view(T, H, N).sum(dim=-1, keepdim=True) * v.view(T, H, N)
-            ).view(T, H * N)
-            return (xx * g) @ O_, x[-1, :], state, v_first
-
-    ########################################################################################################
-
-    @MyStatic
-    def RWKV_x070_CMix_one(x, x_prev, x_k, K_, V_):
-        xx = x_prev - x
-        k = x + xx * x_k
-        k = torch.relu(k @ K_) ** 2
-        return k @ V_, x
-
-    @MyStatic
-    def RWKV_x070_CMix_seq(x, x_prev, x_k, K_, V_):
-        xx = torch.cat((x_prev.unsqueeze(0), x[:-1, :])) - x
-        k = x + xx * x_k
-        k = torch.relu(k @ K_) ** 2
-        return k @ V_, x[-1, :]
-
-
-########################################################################################################
-
-
-class RWKV(MyModule):
-    def __init__(self, model, strategy, verbose=True, convert_and_save_and_exit=None):
-        super().__init__()
-        if verbose:
-            prxxx = lambda *args, **kwargs: print(*args, **kwargs)
-        else:
-            prxxx = lambda *args, **kwargs: None
-
-        STRATEGY_REGEX = r"^(?:(?:^|->) *(?:cuda(?::[\d]+)?|cpu|mps|dml) (?:fp(?:16|32)|bf16)(?:i8|i4|i3)?(?: \*[\d]+\+?)? *)+$"
-        if not re.match(STRATEGY_REGEX, strategy):
-            raise ValueError(
-                "Invalid strategy. Please read https://pypi.org/project/rwkv/"
-            )
-
-        strategy = ("->".join([x.strip() for x in strategy.split("->")])).replace(
-            "->", " -> "
-        )
-        self.args = types.SimpleNamespace()
-        args = self.args
-        args.MODEL_NAME = model
-        args.strategy_string = strategy
-
-        # Rescale for fp16 mode: set x = x/2 every X layer (to avoid fp16 overflow)
-        try:
-            self.RESCALE_LAYER = int(
-                os.environ["RWKV_RESCALE_LAYER"]
-            )  # !!! NOTE: SEEMS YOU SHOULD SET IT TO 999 (disable) FOR RWKV-MUSIC MODELS !!!
-        except:
-            self.RESCALE_LAYER = 6 if "fp16" in strategy else 0
-        prxxx(
-            f'RWKV_JIT_ON {os.environ["RWKV_JIT_ON"]} RWKV_CUDA_ON {os.environ["RWKV_CUDA_ON"]} RESCALE_LAYER {self.RESCALE_LAYER}\n'
-        )
-
-        args.MODEL_NAME = args.MODEL_NAME.strip()
-        if not args.MODEL_NAME.endswith(".pth"):
-            args.MODEL_NAME += ".pth"
-        prxxx(f"Loading {args.MODEL_NAME} ...")
-        with torch.no_grad():
-            self.w = torch.load(
-                args.MODEL_NAME, map_location="cpu"
-            )  # load model to CPU first
-            gc.collect()
-            w = self.w
-
-            ALREADY_CONVERTED = False
-            if "_strategy" in w:
-                ALREADY_CONVERTED = True
-                assert (
-                    convert_and_save_and_exit == None
-                )  # you should only convert a raw model
-                prxxx(
-                    f"Converted model: strategy {w['_strategy']}, version {w['_version']}\n"
-                )
-                assert (
-                    w["_strategy"] == args.strategy_string
-                )  # if you are using a new strategy, re-convert the model
-                assert (
-                    float(w["_version"]) >= 0.7
-                )  # sometimes you should re-convert using latest convert_model.py
-                assert (
-                    w["_rescale_layer"] == self.RESCALE_LAYER
-                )  # must use same RESCALE_LAYER to avoid mistakes
-                del w["_strategy"]
-                del w["_version"]
-                del w["_rescale_layer"]
-
-            args.n_embd = w["emb.weight"].shape[1]
-            args.n_att = w["blocks.0.att.key.weight"].shape[
-                0
-            ]  # note: transposed matrix
-            args.n_ffn = w["blocks.0.ffn.key.weight"].shape[
-                0
-            ]  # note: transposed matrix
-            args.n_layer = 0
-            keys = list(w.keys())
-            self.version = 4
-            for x in keys:
-                layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
-                args.n_layer = max(args.n_layer, layer_id + 1)
-                if "ln_x" in x:
-                    self.version = max(5, self.version)
-                if "gate.weight" in x:
-                    self.version = max(5.1, self.version)
-                if int(self.version) == 5 and "att.time_decay" in x:
-                    args.n_head = w[x].shape[0]
-                    if len(w[x].shape) > 1:
-                        if w[x].shape[1] > 1:
-                            self.version = max(5.2, self.version)
-                if "time_maa" in x:
-                    self.version = max(6, self.version)
-                if int(self.version) == 6 and "time_faaaa" in x:
-                    args.n_head = w[x].shape[0]
-            prxxx(f"Model detected: v{self.version:.1f}")
-
-            ####################### Compute strategy
-
-            s = [x.strip().split(" ") for x in strategy.split("->")]
-            plan = [0] * len(s)
-            stream_i = -1
-            stream_count = 0
-            to_allocate = args.n_layer + 1
-            allocated = 0
-            free_slots = 0
-            for i in range(len(s)):
-                si = s[i]
-                si1 = si[1]
-                if si1.startswith("fp32"):
-                    si[1] = [torch.float]
-                elif si1.startswith("fp16"):
-                    si[1] = [torch.float16]
-                elif si1.startswith("bf16"):
-                    si[1] = [torch.bfloat16]
-                if si1.endswith("i8"):
-                    si[1] += [torch.uint8]
-                else:
-                    si[1] += [si[1][0]]
-                if len(si) > 2:
-                    ss = si[2]
-                    assert ss.startswith("*")
-                    if ss.endswith("+"):
-                        plan[i] = int(ss[1:-1])
-                        stream_i = i
-                    else:
-                        plan[i] = int(ss[1:])
-                    allocated += plan[i]
-                    if allocated >= to_allocate:
-                        plan[i] += to_allocate - allocated
-                        break
-                else:
-                    free_slots += 1
-            if stream_i < 0:
-                if free_slots > 0 and to_allocate > allocated:
-                    for i in range(len(s)):
-                        if plan[i] == 0:
-                            plan[i] = (to_allocate - allocated) // free_slots
-                            allocated += plan[i]
-                            free_slots -= 1
-                if to_allocate > allocated:
-                    plan[len(s) - 1] += to_allocate - allocated
-            else:
-                if to_allocate > allocated:
-                    stream_count = to_allocate - allocated
-                    plan[stream_i] += stream_count
-            prxxx(f"Strategy: (total {args.n_layer}+1={args.n_layer+1} layers)")
-            for i in range(len(s)):
-                ss = s[i]
-                if i != stream_i:
-                    prxxx(
-                        f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]} layers'
-                    )
-                else:
-                    prxxx(
-                        f'* {ss[0]} {str(ss[1]).replace("torch.","")}, store {plan[i]-stream_count} layers, stream {stream_count} layers'
-                    )
-                plan[i] += 0 if i == 0 else plan[i - 1]
-            self.strategy = [None] * (args.n_layer + 1)
-            strategy = self.strategy
-            for n in range(args.n_layer + 1):
-                for i in range(len(s)):
-                    if n < plan[i]:
-                        strategy[n] = types.SimpleNamespace()
-                        strategy[n].device = s[i][0]
-                        strategy[n].atype = s[i][1][0]
-                        strategy[n].wtype = s[i][1][1]
-                        strategy[n].stream = False
-                        if strategy[n].device == "dml":
-                            strategy[n].device = torch_directml.device()
-                        if i == stream_i and n >= (plan[i] - stream_count):
-                            strategy[n].stream = True
-                        break
-                prxxx(
-                    f"{n}-{strategy[n].device}-{str(strategy[n].atype).replace('torch.','')}-{str(strategy[n].wtype).replace('torch.','')}{'-stream' if strategy[n].stream else ''}",
-                    end=" ",
-                )
-            prxxx()
-
-            ####################### Load weights to self.w
-
-            if not ALREADY_CONVERTED:
-                try:  # precompute embedding
-                    w["emb.weight"] = F.layer_norm(
-                        w["emb.weight"],
-                        (args.n_embd,),
-                        weight=w["blocks.0.ln0.weight"],
-                        bias=w["blocks.0.ln0.bias"],
-                    )
-                except:
-                    w["emb.weight"] = F.layer_norm(
-                        w["emb.weight"].float(),
-                        (args.n_embd,),
-                        weight=w["blocks.0.ln0.weight"].float(),
-                        bias=w["blocks.0.ln0.bias"].float(),
-                    )
-                del w["blocks.0.ln0.weight"]
-                del w["blocks.0.ln0.bias"]
-
-            print_need_newline = False
-
-            REAL_TIME_FIRST = False
-            args.time_state = False
-            for x in list(w.keys()):
-                if ".time_faaaa" in x:
-                    REAL_TIME_FIRST = True
-                if ".time_state" in x:
-                    args.time_state = True
-            if REAL_TIME_FIRST:
-                w = {
-                    (
-                        k.replace(".time_faaaa", ".time_first")
-                        if ".time_faaaa" in k
-                        else k
-                    ): v
-                    for k, v in w.items()
-                }
-                self.w = w
-
-            keys = list(w.keys())
-            for x in keys:
-                w[x].requires_grad = False
-                layer_id = int(x.split(".")[1]) if ("blocks." in x) else 0
-                if ("ln_out." in x) or ("head." in x):
-                    layer_id = args.n_layer
-                dd = strategy[layer_id]
-                DEVICE = dd.device
-                ATYPE = dd.atype
-                WTYPE = dd.wtype
-
-                if not ALREADY_CONVERTED:
-                    if self.RESCALE_LAYER > 0:
-                        if "att.output.weight" in x:
-                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
-                        if "ffn.value.weight" in x:
-                            w[x] = w[x] / (2 ** int(layer_id // self.RESCALE_LAYER))
-
-                    if ".time_" in x:
-                        w[x] = w[x].squeeze()
-                    if (
-                        "key.weight" in x
-                        or "value.weight" in x
-                        or "receptance.weight" in x
-                        or "gate.weight" in x
-                        or "output.weight" in x
-                        or "head.weight" in x
-                    ):
-                        w[x] = w[x].t()
-
-                    if ".time_decay" in x and "_w" not in x:  # need fp32 for this
-                        if self.version == 4:
-                            w[x] = -torch.exp(w[x].float())
-                        elif int(self.version) == 5:
-                            w[x] = torch.exp(-torch.exp(w[x].float())).reshape(-1, 1, 1)
-                            if self.version == 5.2:
-                                w[x] = w[x].reshape(args.n_head, -1, 1)
-                        elif self.version == 6.0:
-                            w[x] = w[x].float().reshape(args.n_head, -1, 1)
-                    elif ".time_first" in x:  # need fp32 for this
-                        if self.version == 4:
-                            w[x] = w[x].float()
-                        elif int(self.version) in [5, 6]:
-                            if REAL_TIME_FIRST:
-                                w[x] = w[x].float().reshape(-1, 1, 1)
-                            else:
-                                w[x] = torch.exp(w[x].float()).reshape(-1, 1, 1)
-                            if self.version in [5.2, 6.0]:
-                                w[x] = w[x].reshape(args.n_head, -1, 1)
-                    elif ".ln_x" in x:  # need fp32 for group_norm
-                        w[x] = w[x].float()
-                    else:
-                        if (
-                            (len(w[x].shape) == 2)
-                            and ("emb" not in x)
-                            and ("_w1" not in x)
-                            and ("_w2" not in x)
-                        ):
-                            if WTYPE != torch.uint8:
-                                w[x] = w[x].to(dtype=WTYPE)
-                            else:
-                                w[x] = w[x].float()
-
-                                if w[x].shape[0] > w[x].shape[1]:
-                                    w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1)
-                                    w[x] = w[x] - w[x + "_my"]
-                                    w[x + "_mx"] = torch.amin(w[x], dim=0)
-                                    w[x] = w[x] - w[x + "_mx"]
-                                    w[x + "_rx"] = torch.amax(w[x], dim=0)
-                                    w[x] = w[x] / w[x + "_rx"]
-                                    w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1)
-                                    w[x] = w[x] / w[x + "_ry"]
-                                else:
-                                    w[x + "_mx"] = torch.amin(w[x], dim=0)
-                                    w[x] = w[x] - w[x + "_mx"]
-                                    w[x + "_my"] = torch.amin(w[x], dim=1).unsqueeze(1)
-                                    w[x] = w[x] - w[x + "_my"]
-                                    w[x + "_rx"] = torch.amax(w[x], dim=0)
-                                    w[x] = w[x] / w[x + "_rx"]
-                                    w[x + "_ry"] = torch.amax(w[x], dim=1).unsqueeze(1)
-                                    w[x] = w[x] / w[x + "_ry"]
-
-                                w[x] = torch.clip(
-                                    torch.floor(w[x] * 256), min=0, max=255
-                                ).to(dtype=torch.uint8)
-                                w[x + "_mx"] = w[x + "_mx"].to(dtype=ATYPE).contiguous()
-                                w[x + "_rx"] = (
-                                    (w[x + "_rx"] / 16).to(dtype=ATYPE).contiguous()
-                                )
-                                w[x + "_my"] = w[x + "_my"].to(dtype=ATYPE).contiguous()
-                                w[x + "_ry"] = (
-                                    (w[x + "_ry"] / 16).to(dtype=ATYPE).contiguous()
-                                )
-                        else:
-                            w[x] = w[x].to(dtype=ATYPE)
-
-                if convert_and_save_and_exit == None:
-                    if "emb." in x:
-                        w[x] = w[x].contiguous()
-                    elif (dd.stream) and (
-                        x.endswith("key.weight")
-                        or x.endswith("value.weight")
-                        or x.endswith("receptance.weight")
-                        or x.endswith("output.weight")
-                    ):
-                        try:
-                            w[x] = (
-                                w[x].contiguous().pin_memory()
-                            )  # if you see "CUDA error: out of memory" here, that's out of CPU RAM, not VRAM. Get more RAM :)
-                        except:
-                            print(
-                                "Note: You are running out of RAM. Get more CPU RAM. Now this will run much slower."
-                            )
-                    elif DEVICE != "cpu":
-                        w[x] = w[x].to(device=DEVICE).contiguous()
-
-                    if (dd.stream) or (DEVICE != "cpu"):
-                        try:
-                            w[x + "_mx"] = w[x + "_mx"].to(device=DEVICE).contiguous()
-                            w[x + "_rx"] = w[x + "_rx"].to(device=DEVICE).contiguous()
-                            w[x + "_my"] = w[x + "_my"].to(device=DEVICE).contiguous()
-                            w[x + "_ry"] = w[x + "_ry"].to(device=DEVICE).contiguous()
-                        except:
-                            pass
-
-                if "ffn.value.weight" in x:
-                    gc.collect()
-                    if "cuda" in args.strategy_string:
-                        torch.cuda.empty_cache()
-
-                shape = [i for i in w[x].shape if i != 1]
-                if len(shape) > 2:
-                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)} {str(shape[2]).rjust(5)}"
-                elif len(shape) > 1:
-                    shape = f" {str(shape[0]).rjust(5)} {str(shape[1]).rjust(5)}      "
-                else:
-                    shape = f" {str(shape[0]).rjust(5)}            "
-                if layer_id == 0 or layer_id >= args.n_layer - 1:
-                    if print_need_newline:
-                        prxxx("\n", end="")
-                        print_need_newline = False
-                    dt = str(w[x].dtype).replace("torch.", "")
-                    dt = (
-                        dt.replace("float32", "f32")
-                        .replace("bfloat16", "bf16")
-                        .replace("float16", "f16")
-                        .replace("uint8", "i8")
-                    )
-                    prxxx(
-                        x.ljust(32),
-                        dt.rjust(4),
-                        str(w[x].device).rjust(8),
-                        shape,
-                        " (pinned)" if w[x].is_pinned() else "",
-                    )
-                else:
-                    print_need_newline = True
-                    prxxx(".", end="", flush=True)
-
-            if convert_and_save_and_exit:
-                w["_strategy"] = args.strategy_string
-                w["_rescale_layer"] = self.RESCALE_LAYER
-                w["_version"] = "0.7"
-                if not convert_and_save_and_exit.endswith(".pth"):
-                    convert_and_save_and_exit += ".pth"
-                prxxx(f"Saving to {convert_and_save_and_exit}...")
-                torch.save(w, convert_and_save_and_exit)
-                prxxx(f"Converted and saved. Now this will exit.")
-                exit(0)
-
-            if self.version == 5.2 and os.environ["RWKV_CUDA_ON"] == "1":
-                HEAD_SIZE = args.n_att // args.n_head
-                rwkv5 = load(
-                    name="rwkv5",
-                    sources=[
-                        f"{current_path}/cuda/rwkv5_op.cpp",
-                        f"{current_path}/cuda/rwkv5.cu",
-                    ],
-                    verbose=True,
-                    extra_cuda_cflags=[
-                        "-res-usage",
-                        "--use_fast_math",
-                        "-O3",
-                        "-Xptxas -O3" if os.name != "nt" else "",
-                        "--extra-device-vectorization",
-                        f"-D_N_={HEAD_SIZE}",
-                    ],
-                )
-
-                class RWKV_5(torch.autograd.Function):
-                    @staticmethod
-                    def forward(ctx, B, T, C, H, state, r, k, v, w, u):
-                        with torch.no_grad():
-                            assert HEAD_SIZE == C // H
-                            ctx.B = B
-                            ctx.T = T
-                            ctx.C = C
-                            ctx.H = H
-                            assert state.dtype == torch.float32
-                            assert w.dtype == torch.float32
-                            assert r.is_contiguous()
-                            assert k.is_contiguous()
-                            assert v.is_contiguous()
-                            assert w.is_contiguous()
-                            assert u.is_contiguous()
-                            assert state.is_contiguous()
-
-                            y = torch.empty(
-                                (B, T, C),
-                                device=w.device,
-                                dtype=r.dtype,
-                                memory_format=torch.contiguous_format,
-                            )
-                            if r.dtype == torch.bfloat16:
-                                rwkv5.forward_bf16(B, T, C, H, state, r, k, v, w, u, y)
-                            elif r.dtype == torch.float16:
-                                rwkv5.forward_fp16(B, T, C, H, state, r, k, v, w, u, y)
-                            elif r.dtype == torch.float32:
-                                rwkv5.forward_fp32(B, T, C, H, state, r, k, v, w, u, y)
-                            return y, state
-
-                self.RWKV_5 = RWKV_5
-
-            if self.version == 6.0 and os.environ["RWKV_CUDA_ON"] == "1":
-                HEAD_SIZE = args.n_att // args.n_head
-                rwkv6 = load(
-                    name="rwkv6",
-                    sources=[
-                        f"{current_path}/cuda/rwkv6_op.cpp",
-                        f"{current_path}/cuda/rwkv6.cu",
-                    ],
-                    verbose=True,
-                    extra_cuda_cflags=[
-                        "-res-usage",
-                        "--use_fast_math",
-                        "-O3",
-                        "-Xptxas -O3" if os.name != "nt" else "",
-                        "--extra-device-vectorization",
-                        f"-D_N_={HEAD_SIZE}",
-                        f"-D_T_={4096}",
-                    ],
-                )
-
-                class RWKV_6(torch.autograd.Function):
-                    @staticmethod
-                    def forward(ctx, B, T, C, H, state, r, k, v, w, u):
-                        with torch.no_grad():
-                            assert HEAD_SIZE == C // H
-                            ctx.B = B
-                            ctx.T = T
-                            ctx.C = C
-                            ctx.H = H
-                            assert state.dtype == torch.float32
-                            assert w.dtype == torch.float32
-                            assert r.is_contiguous()
-                            assert k.is_contiguous()
-                            assert v.is_contiguous()
-                            assert w.is_contiguous()
-                            assert u.is_contiguous()
-                            eew = torch.exp(-torch.exp(w.float())).contiguous()
-
-                            y = torch.empty(
-                                (B, T, C),
-                                device=w.device,
-                                dtype=r.dtype,
-                                memory_format=torch.contiguous_format,
-                            )
-                            if r.dtype == torch.bfloat16:
-                                rwkv6.forward_bf16(
-                                    B, T, C, H, state, r, k, v, eew, u, y
-                                )
-                            elif r.dtype == torch.float16:
-                                rwkv6.forward_fp16(
-                                    B, T, C, H, state, r, k, v, eew, u, y
-                                )
-                            elif r.dtype == torch.float32:
-                                rwkv6.forward_fp32(
-                                    B, T, C, H, state, r, k, v, eew, u, y
-                                )
-                            return y, state
-
-                self.RWKV_6 = RWKV_6
-
-            gc.collect()
-            if "cuda" in args.strategy_string:
-                torch.cuda.empty_cache()
-
-    def RUN_RWKV_5(self, B, T, C, H, state, r, k, v, w, u):
-        return self.RWKV_5.apply(B, T, C, H, state, r, k, v, w, u)
-
-    def RUN_RWKV_6(self, B, T, C, H, state, r, k, v, w, u):
-        return self.RWKV_6.apply(B, T, C, H, state, r, k, v, w, u)
-
-    ########################################################################################################
-
-    @MyFunction
-    def ffn_one(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_mix,
-        r_mix,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
-        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
-        return x + out, xx
-
-    @MyFunction
-    def ffn_seq(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_mix,
-        r_mix,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
-        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
-        return x + out, xx[-1, :]
-
-    @MyFunction
-    def ffn_one_v6(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_maa,
-        r_maa,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = sx - xx
-        kx = xx + sx * k_maa
-        rx = xx + sx * r_maa
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
-        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
-        return x + out, xx
-
-    @MyFunction
-    def ffn_seq_v6(
-        self,
-        x,
-        sx,
-        ln_w,
-        ln_b,
-        k_maa,
-        r_maa,
-        kw,
-        vw,
-        rw,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        sx = sx - xx
-        kx = xx + sx * k_maa
-        rx = xx + sx * r_maa
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        vx = torch.relu(matmul(kx, kw, kmx, krx, kmy, kry)) ** 2
-        out = r * matmul(vx, vw, vmx, vrx, vmy, vry)
-        return x + out, xx[-1, :]
-
-    ########################################################################################################
-
-    @MyFunction
-    def att_one(
-        self,
-        x,
-        sx,
-        aa,
-        bb,
-        pp,
-        ln_w,
-        ln_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-
-        ww = t_first + k
-        p = torch.maximum(pp, ww)
-        e1 = torch.exp(pp - p)
-        e2 = torch.exp(ww - p)
-        wkv = ((e1 * aa + e2 * v) / (e1 * bb + e2)).to(dtype=x.dtype)
-        ww = t_decay + pp
-        p = torch.maximum(ww, k)
-        e1 = torch.exp(ww - p)
-        e2 = torch.exp(k - p)
-
-        out = matmul(r * wkv, ow, omx, orx, omy, ory)
-        return x + out, xx, e1 * aa + e2 * v, e1 * bb + e2, p
-
-    @MyFunction
-    def att_seq(
-        self,
-        x,
-        sx,
-        aa,
-        bb,
-        pp,
-        ln_w,
-        ln_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-
-        T = x.shape[0]
-        for t in range(T):
-            kk = k[t]
-            vv = v[t]
-            ww = t_first + kk
-            p = torch.maximum(pp, ww)
-            e1 = torch.exp(pp - p)
-            e2 = torch.exp(ww - p)
-            sx[t] = ((e1 * aa + e2 * vv) / (e1 * bb + e2)).to(dtype=x.dtype)
-            ww = t_decay + pp
-            p = torch.maximum(ww, kk)
-            e1 = torch.exp(ww - p)
-            e2 = torch.exp(kk - p)
-            aa = e1 * aa + e2 * vv
-            bb = e1 * bb + e2
-            pp = p
-        out = matmul(r * sx, ow, omx, orx, omy, ory)
-        return x + out, xx[-1, :], aa, bb, pp
-
-    ########################################################################################################
-
-    @MyFunction
-    def att_one_v5(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-
-        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
-        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
-        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)
-
-        a = matmul(k, v)
-        out = r @ (t_first * a + s)
-        s = a + t_decay * s
-
-        out = out.flatten()
-        out = F.group_norm(
-            out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5
-        ).squeeze(0)
-        out = out.to(dtype=x.dtype)
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx, s
-
-    @MyFunction
-    def att_seq_v5(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-        T = x.shape[0]
-
-        w = t_decay.reshape(-1, 1)
-        u = t_first.reshape(-1, 1)
-        ws = w.pow(T).reshape(H, 1, 1)
-        ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
-        w = w.repeat(1, T).pow(ind)
-        wk = w.reshape(H, 1, T)
-        wb = wk.transpose(-2, -1).flip(1)
-        w = torch.cat([w[:, 1:], u], dim=1)
-        w = F.pad(w, (0, T))
-        w = torch.tile(w, [T])
-        w = w[:, :-T].reshape(-1, T, 2 * T - 1)
-        w = w[:, :, T - 1 :].reshape(H, T, T)
-
-        r = (
-            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        k = (
-            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .permute(1, 2, 0)
-        )
-        v = (
-            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-
-        out = ((r @ k) * w) @ v + (r @ s) * wb
-        s = ws * s + (k * wk) @ v
-
-        out = out.transpose(0, 1).contiguous().reshape(T, H * N)
-        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
-        out = out.to(dtype=x.dtype)
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx[-1, :], s
-
-    ########################################################################################################
-
-    @MyFunction
-    def att_one_v5_1(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        g_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        gw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        gmx,
-        grx,
-        gmy,
-        gry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-        gx = xx * g_mix + sx * (1 - g_mix)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-
-        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
-        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
-        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)
-        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-        a = matmul(k, v)
-        out = r @ (t_first * a + s)
-        s = a + t_decay * s
-
-        out = out.flatten()
-        out = F.group_norm(
-            out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5
-        ).squeeze(0)
-        out = out.to(dtype=x.dtype) * g
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx, s
-
-    @MyFunction
-    def att_seq_v5_1(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        g_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        gw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        gmx,
-        grx,
-        gmy,
-        gry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-        gx = xx * g_mix + sx * (1 - g_mix)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-        T = x.shape[0]
-
-        w = t_decay.reshape(-1, 1)
-        u = t_first.reshape(-1, 1)
-        ws = w.pow(T).reshape(H, 1, 1)
-        ind = torch.arange(T - 1, -1, -1, device=w.device).unsqueeze(0).repeat(H, 1)
-        w = w.repeat(1, T).pow(ind)
-        wk = w.reshape(H, 1, T)
-        wb = wk.transpose(-2, -1).flip(1)
-        w = torch.cat([w[:, 1:], u], dim=1)
-        w = F.pad(w, (0, T))
-        w = torch.tile(w, [T])
-        w = w[:, :-T].reshape(-1, T, 2 * T - 1)
-        w = w[:, :, T - 1 :].reshape(H, T, T)
-
-        r = (
-            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        k = (
-            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .permute(1, 2, 0)
-        )
-        v = (
-            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-        out = ((r @ k) * w) @ v + (r @ s) * wb
-        s = ws * s + (k * wk) @ v
-
-        out = out.transpose(0, 1).contiguous().reshape(T, H * N)
-        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
-        out = out.to(dtype=x.dtype) * g
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx[-1, :], s
-
-    ########################################################################################################
-
-    @MyFunction
-    def att_seq_v5_2(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        k_mix,
-        v_mix,
-        r_mix,
-        g_mix,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        gw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        gmx,
-        grx,
-        gmy,
-        gry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-        kx = xx * k_mix + sx * (1 - k_mix)
-        vx = xx * v_mix + sx * (1 - v_mix)
-        rx = xx * r_mix + sx * (1 - r_mix)
-        gx = xx * g_mix + sx * (1 - g_mix)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-        T = x.shape[0]
-
-        r = (
-            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        k = (
-            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .permute(1, 2, 0)
-        )
-        v = (
-            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-        out = torch.empty((T, H, N), dtype=r.dtype, device=r.device)
-        for t in range(T):
-            rt = r[:, t : t + 1, :]
-            kt = k[:, :, t : t + 1]
-            vt = v[:, t : t + 1, :]
-            at = matmul(kt, vt)
-            out[t] = (rt @ (t_first * at + s)).squeeze(1)
-            s = at + t_decay * s
-
-        out = out.reshape(T, H * N)
-        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
-        out = out.to(dtype=x.dtype) * g
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx[-1, :], s
-
-    ########################################################################################################
-
-    @MyFunction
-    def att_one_v6_0(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        x_maa,
-        w_maa,
-        k_maa,
-        v_maa,
-        r_maa,
-        g_maa,
-        tm_w1,
-        tm_w2,
-        td_w1,
-        td_w2,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        gw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        gmx,
-        grx,
-        gmy,
-        gry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-
-        sx = sx - xx
-        xxx = xx + sx * x_maa
-        xxx = torch.tanh(xxx @ tm_w1).view(5, 1, -1)
-        xxx = torch.bmm(xxx, tm_w2).view(5, -1)
-        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
-
-        wx = xx + sx * (w_maa + mw)
-        kx = xx + sx * (k_maa + mk)
-        vx = xx + sx * (v_maa + mv)
-        rx = xx + sx * (r_maa + mr)
-        gx = xx + sx * (g_maa + mg)
-
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-
-        r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32).view(H, 1, N)
-        k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32).view(H, N, 1)
-        v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32).view(H, 1, N)
-        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-        w = t_decay + (torch.tanh(wx @ td_w1) @ td_w2).float().view(H, N, 1)
-        w = torch.exp(-torch.exp(w.float()))
-
-        a = matmul(k, v)
-        out = r @ (t_first * a + s)
-        s = a + w * s
-
-        out = out.flatten()
-        out = F.group_norm(
-            out.unsqueeze(0), num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5
-        ).squeeze(0)
-        out = out.to(dtype=x.dtype) * g
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx, s
-
-    @MyFunction
-    def att_seq_v6_0(
-        self,
-        x,
-        sx,
-        s,
-        ln_w,
-        ln_b,
-        lx_w,
-        lx_b,
-        x_maa,
-        w_maa,
-        k_maa,
-        v_maa,
-        r_maa,
-        g_maa,
-        tm_w1,
-        tm_w2,
-        td_w1,
-        td_w2,
-        t_decay,
-        t_first,
-        kw,
-        vw,
-        rw,
-        gw,
-        ow,
-        kmx,
-        krx,
-        kmy,
-        kry,
-        vmx,
-        vrx,
-        vmy,
-        vry,
-        rmx,
-        rrx,
-        rmy,
-        rry,
-        gmx,
-        grx,
-        gmy,
-        gry,
-        omx,
-        orx,
-        omy,
-        ory,
-    ):
-        H = t_decay.shape[0]
-        N = x.shape[-1] // H
-        T = x.shape[0]
-
-        xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-        sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx
-        xxx = xx + sx * x_maa
-        xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1)
-        xxx = torch.bmm(xxx, tm_w2).view(5, T, -1)
-        mw, mk, mv, mr, mg = xxx.unbind(dim=0)
-
-        wx = xx + sx * (w_maa + mw)
-        kx = xx + sx * (k_maa + mk)
-        vx = xx + sx * (v_maa + mv)
-        rx = xx + sx * (r_maa + mr)
-        gx = xx + sx * (g_maa + mg)
-
-        r = (
-            matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        k = (
-            matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .permute(1, 2, 0)
-        )
-        v = (
-            matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            .view(T, H, N)
-            .transpose(0, 1)
-        )
-        g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-        w = t_decay.view(1, H, N, 1) + (torch.tanh(wx @ td_w1) @ td_w2).float().view(
-            T, H, N, 1
-        )
-        w = torch.exp(-torch.exp(w.float()))
-        out = torch.empty((T, H, N), dtype=r.dtype, device=r.device)
-        for t in range(T):
-            rt = r[:, t : t + 1, :]
-            kt = k[:, :, t : t + 1]
-            vt = v[:, t : t + 1, :]
-            at = matmul(kt, vt)
-            out[t] = (rt @ (t_first * at + s)).squeeze(1)
-            s = at + w[t] * s
-
-        out = out.reshape(T, H * N)
-        out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
-        out = out.to(dtype=x.dtype) * g
-        out = matmul(out, ow, omx, orx, omy, ory)
-
-        return x + out, xx[-1, :], s
-
-    ########################################################################################################
-
-    if os.environ["RWKV_CUDA_ON"] == "1":
-
-        @MyFunction
-        def cuda_att_seq(
-            self,
-            x,
-            sx,
-            aa,
-            bb,
-            pp,
-            ln_w,
-            ln_b,
-            k_mix,
-            v_mix,
-            r_mix,
-            t_decay,
-            t_first,
-            kw,
-            vw,
-            rw,
-            ow,
-            kmx,
-            krx,
-            kmy,
-            kry,
-            vmx,
-            vrx,
-            vmy,
-            vry,
-            rmx,
-            rrx,
-            rmy,
-            rry,
-            omx,
-            orx,
-            omy,
-            ory,
-        ):
-            T, C = x.shape
-            xx = F.layer_norm(x, (C,), weight=ln_w, bias=ln_b)
-            sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-            kx = xx * k_mix + sx * (1 - k_mix)
-            vx = xx * v_mix + sx * (1 - v_mix)
-            rx = xx * r_mix + sx * (1 - r_mix)
-
-            r = torch.sigmoid(matmul(rx, rw, rmx, rrx, rmy, rry))
-            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            y, aa, bb, pp = cuda_wkv(T, C, t_decay, t_first, k, v, aa, bb, pp)
-
-            out = matmul(r * y.to(x.dtype), ow, omx, orx, omy, ory)
-            return x + out, xx[-1, :], aa, bb, pp
-
-        @MyFunction
-        def v5_2_before(
-            self,
-            x,
-            sx,
-            s,
-            ln_w,
-            ln_b,
-            lx_w,
-            lx_b,
-            k_mix,
-            v_mix,
-            r_mix,
-            g_mix,
-            t_decay,
-            t_first,
-            kw,
-            vw,
-            rw,
-            gw,
-            ow,
-            kmx,
-            krx,
-            kmy,
-            kry,
-            vmx,
-            vrx,
-            vmy,
-            vry,
-            rmx,
-            rrx,
-            rmy,
-            rry,
-            gmx,
-            grx,
-            gmy,
-            gry,
-            omx,
-            orx,
-            omy,
-            ory,
-        ):
-            xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-            sx = torch.cat((sx.unsqueeze(0), xx[:-1, :]))
-            kx = xx * k_mix + sx * (1 - k_mix)
-            vx = xx * v_mix + sx * (1 - v_mix)
-            rx = xx * r_mix + sx * (1 - r_mix)
-            gx = xx * g_mix + sx * (1 - g_mix)
-
-            r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-            return r, k, v, g, xx[-1, :], s.transpose(-1, -2).contiguous()
-
-        @MyFunction
-        def v5_2_after(
-            self, t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory
-        ):
-            H = t_decay.shape[0]
-            N = x.shape[-1] // H
-            T = x.shape[0]
-
-            s = s.transpose(-1, -2)
-            out = out.reshape(T, H * N)
-            out = F.group_norm(out, num_groups=H, weight=lx_w, bias=lx_b, eps=64e-5)
-            out = out.to(dtype=x.dtype) * g
-            out = matmul(out, ow, omx, orx, omy, ory)
-
-            return x + out, xxx, s
-
-        def cuda_att_seq_v5_2(
-            self,
-            x,
-            sx,
-            s,
-            ln_w,
-            ln_b,
-            lx_w,
-            lx_b,
-            k_mix,
-            v_mix,
-            r_mix,
-            g_mix,
-            t_decay,
-            t_first,
-            kw,
-            vw,
-            rw,
-            gw,
-            ow,
-            kmx,
-            krx,
-            kmy,
-            kry,
-            vmx,
-            vrx,
-            vmy,
-            vry,
-            rmx,
-            rrx,
-            rmy,
-            rry,
-            gmx,
-            grx,
-            gmy,
-            gry,
-            omx,
-            orx,
-            omy,
-            ory,
-        ):
-            H = t_decay.shape[0]
-            N = x.shape[-1] // H
-            T = x.shape[0]
-
-            r, k, v, g, xxx, ss = self.v5_2_before(
-                x,
-                sx,
-                s,
-                ln_w,
-                ln_b,
-                lx_w,
-                lx_b,
-                k_mix,
-                v_mix,
-                r_mix,
-                g_mix,
-                t_decay,
-                t_first,
-                kw,
-                vw,
-                rw,
-                gw,
-                ow,
-                kmx,
-                krx,
-                kmy,
-                kry,
-                vmx,
-                vrx,
-                vmy,
-                vry,
-                rmx,
-                rrx,
-                rmy,
-                rry,
-                gmx,
-                grx,
-                gmy,
-                gry,
-                omx,
-                orx,
-                omy,
-                ory,
-            )
-
-            out, s = self.RUN_RWKV_5(
-                1, T, self.args.n_att, H, ss, r, k, v, w=t_decay, u=t_first
-            )
-
-            return self.v5_2_after(
-                t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory
-            )
-
-        @MyFunction
-        def v6_0_before(
-            self,
-            x,
-            sx,
-            s,
-            ln_w,
-            ln_b,
-            lx_w,
-            lx_b,
-            x_maa,
-            w_maa,
-            k_maa,
-            v_maa,
-            r_maa,
-            g_maa,
-            tm_w1,
-            tm_w2,
-            td_w1,
-            td_w2,
-            t_decay,
-            t_first,
-            kw,
-            vw,
-            rw,
-            gw,
-            ow,
-            kmx,
-            krx,
-            kmy,
-            kry,
-            vmx,
-            vrx,
-            vmy,
-            vry,
-            rmx,
-            rrx,
-            rmy,
-            rry,
-            gmx,
-            grx,
-            gmy,
-            gry,
-            omx,
-            orx,
-            omy,
-            ory,
-        ):
-            H = t_decay.shape[0]
-            N = x.shape[-1] // H
-            T = x.shape[0]
-
-            xx = F.layer_norm(x, (x.shape[-1],), weight=ln_w, bias=ln_b)
-            sx = torch.cat((sx.unsqueeze(0), xx[:-1, :])) - xx
-            xxx = xx + sx * x_maa
-            xxx = torch.tanh(xxx @ tm_w1).view(T, 5, -1).transpose(0, 1)
-            xxx = torch.bmm(xxx, tm_w2).view(5, T, -1)
-            mw, mk, mv, mr, mg = xxx.unbind(dim=0)
-
-            wx = xx + sx * (w_maa + mw)
-            kx = xx + sx * (k_maa + mk)
-            vx = xx + sx * (v_maa + mv)
-            rx = xx + sx * (r_maa + mr)
-            gx = xx + sx * (g_maa + mg)
-
-            r = matmul(rx, rw, rmx, rrx, rmy, rry, output_dtype=torch.float32)
-            k = matmul(kx, kw, kmx, krx, kmy, kry, output_dtype=torch.float32)
-            v = matmul(vx, vw, vmx, vrx, vmy, vry, output_dtype=torch.float32)
-            g = F.silu(matmul(gx, gw, gmx, grx, gmy, gry))
-
-            w = t_decay.view(1, H, N, 1) + (
-                torch.tanh(wx @ td_w1) @ td_w2
-            ).float().view(T, H, N, 1)
-
-            return r, k, v, g, w, xx[-1, :], s.transpose(-1, -2).contiguous()
-
-        def cuda_att_seq_v6_0(
-            self,
-            x,
-            sx,
-            s,
-            ln_w,
-            ln_b,
-            lx_w,
-            lx_b,
-            x_maa,
-            w_maa,
-            k_maa,
-            v_maa,
-            r_maa,
-            g_maa,
-            tm_w1,
-            tm_w2,
-            td_w1,
-            td_w2,
-            t_decay,
-            t_first,
-            kw,
-            vw,
-            rw,
-            gw,
-            ow,
-            kmx,
-            krx,
-            kmy,
-            kry,
-            vmx,
-            vrx,
-            vmy,
-            vry,
-            rmx,
-            rrx,
-            rmy,
-            rry,
-            gmx,
-            grx,
-            gmy,
-            gry,
-            omx,
-            orx,
-            omy,
-            ory,
-        ):
-            H = t_decay.shape[0]
-            N = x.shape[-1] // H
-            T = x.shape[0]
-
-            r, k, v, g, w, xxx, ss = self.v6_0_before(
-                x,
-                sx,
-                s,
-                ln_w,
-                ln_b,
-                lx_w,
-                lx_b,
-                x_maa,
-                w_maa,
-                k_maa,
-                v_maa,
-                r_maa,
-                g_maa,
-                tm_w1,
-                tm_w2,
-                td_w1,
-                td_w2,
-                t_decay,
-                t_first,
-                kw,
-                vw,
-                rw,
-                gw,
-                ow,
-                kmx,
-                krx,
-                kmy,
-                kry,
-                vmx,
-                vrx,
-                vmy,
-                vry,
-                rmx,
-                rrx,
-                rmy,
-                rry,
-                gmx,
-                grx,
-                gmy,
-                gry,
-                omx,
-                orx,
-                omy,
-                ory,
-            )
-
-            out, s = self.RUN_RWKV_6(
-                1, T, self.args.n_att, H, ss, r, k, v, w=w, u=t_first
-            )
-            return self.v5_2_after(
-                t_decay, out, s, x, xxx, g, lx_w, lx_b, ow, omx, orx, omy, ory
-            )
-
-    ########################################################################################################
-
-    def forward(self, tokens, state, full_output=False):
-        with torch.no_grad():
-            w = self.w
-            args = self.args
-
-            if state == None:
-                if self.version == 4:
-                    state = [None] * args.n_layer * 5
-                    for i in range(
-                        args.n_layer
-                    ):  # state: 0=att_xx 1=att_aa 2=att_bb 3=att_pp 4=ffn_xx
-                        dd = self.strategy[i]
-                        dev = dd.device
-                        atype = dd.atype
-                        state[i * 5 + 0] = torch.zeros(
-                            args.n_embd, dtype=atype, requires_grad=False, device=dev
-                        ).contiguous()
-                        state[i * 5 + 1] = torch.zeros(
-                            args.n_att,
-                            dtype=torch.float,
-                            requires_grad=False,
-                            device=dev,
-                        ).contiguous()
-                        state[i * 5 + 2] = torch.zeros(
-                            args.n_att,
-                            dtype=torch.float,
-                            requires_grad=False,
-                            device=dev,
-                        ).contiguous()
-                        state[i * 5 + 3] = (
-                            torch.zeros(
-                                args.n_att,
-                                dtype=torch.float,
-                                requires_grad=False,
-                                device=dev,
-                            ).contiguous()
-                            - 1e30
-                        )
-                        state[i * 5 + 4] = torch.zeros(
-                            args.n_embd, dtype=atype, requires_grad=False, device=dev
-                        ).contiguous()
-                elif int(self.version) in [5, 6]:
-                    state = [None] * args.n_layer * 3
-                    for i in range(args.n_layer):  # state: 0=att_xx 1=att_kv 2=ffn_xx
-                        dd = self.strategy[i]
-                        dev = dd.device
-                        atype = dd.atype
-                        state[i * 3 + 0] = torch.zeros(
-                            args.n_embd, dtype=atype, requires_grad=False, device=dev
-                        ).contiguous()
-                        if args.time_state:
-                            state[i * 3 + 1] = (
-                                w[f"blocks.{i}.att.time_state"]
-                                .transpose(1, 2)
-                                .to(dtype=torch.float, device=dev)
-                                .requires_grad_(False)
-                                .contiguous()
-                            )
-                        else:
-                            state[i * 3 + 1] = torch.zeros(
-                                (
-                                    args.n_head,
-                                    args.n_att // args.n_head,
-                                    args.n_att // args.n_head,
-                                ),
-                                dtype=torch.float,
-                                requires_grad=False,
-                                device=dev,
-                            ).contiguous()
-                        state[i * 3 + 2] = torch.zeros(
-                            args.n_embd, dtype=atype, requires_grad=False, device=dev
-                        ).contiguous()
-
-            seq_mode = len(tokens) > 1
-
-            x = w["emb.weight"][tokens if seq_mode else tokens[0]]
-
-            for i in range(args.n_layer):
-                bbb = f"blocks.{i}."
-                att = f"blocks.{i}.att."
-                ffn = f"blocks.{i}.ffn."
-                dd = self.strategy[i]
-                dev = dd.device
-                atype = dd.atype
-                wtype = dd.wtype
-                if seq_mode:
-                    cuda_applicable = os.environ[
-                        "RWKV_CUDA_ON"
-                    ] == "1" and "cuda" in str(dev)
-                    if cuda_applicable:
-                        ATT = self.cuda_att_seq
-                    else:
-                        ATT = self.att_seq
-                    if self.version == 5:
-                        ATT = self.att_seq_v5
-                    elif self.version == 5.1:
-                        ATT = self.att_seq_v5_1
-                    elif self.version == 5.2:
-                        ATT = self.att_seq_v5_2
-                        if cuda_applicable:
-                            ATT = self.cuda_att_seq_v5_2
-                    elif self.version == 6.0:
-                        ATT = self.att_seq_v6_0
-                        if cuda_applicable:
-                            ATT = self.cuda_att_seq_v6_0
-                    FFN = self.ffn_seq
-                    if self.version >= 6.0:
-                        FFN = self.ffn_seq_v6
-                else:
-                    ATT = self.att_one
-                    if self.version == 5:
-                        ATT = self.att_one_v5
-                    elif self.version == 5.1:
-                        ATT = self.att_one_v5_1
-                    elif self.version == 5.2:
-                        ATT = self.att_one_v5_1  # same as v5.1
-                    elif self.version == 6.0:
-                        ATT = self.att_one_v6_0
-                    FFN = self.ffn_one
-                    if self.version >= 6.0:
-                        FFN = self.ffn_one_v6
-
-                x = x.to(dtype=atype, device=dev)
-
-                kw = w[f"{att}key.weight"]
-                vw = w[f"{att}value.weight"]
-                rw = w[f"{att}receptance.weight"]
-                ow = w[f"{att}output.weight"]
-                if dd.stream:
-                    kw = kw.to(device=dev, non_blocking=True)
-                    vw = vw.to(device=dev, non_blocking=True)
-                    rw = rw.to(device=dev, non_blocking=True)
-                    ow = ow.to(device=dev, non_blocking=True)
-                kmx = w[f"{att}key.weight_mx"] if wtype == torch.uint8 else x
-                krx = w[f"{att}key.weight_rx"] if wtype == torch.uint8 else x
-                kmy = w[f"{att}key.weight_my"] if wtype == torch.uint8 else x
-                kry = w[f"{att}key.weight_ry"] if wtype == torch.uint8 else x
-                vmx = w[f"{att}value.weight_mx"] if wtype == torch.uint8 else x
-                vrx = w[f"{att}value.weight_rx"] if wtype == torch.uint8 else x
-                vmy = w[f"{att}value.weight_my"] if wtype == torch.uint8 else x
-                vry = w[f"{att}value.weight_ry"] if wtype == torch.uint8 else x
-                rmx = w[f"{att}receptance.weight_mx"] if wtype == torch.uint8 else x
-                rrx = w[f"{att}receptance.weight_rx"] if wtype == torch.uint8 else x
-                rmy = w[f"{att}receptance.weight_my"] if wtype == torch.uint8 else x
-                rry = w[f"{att}receptance.weight_ry"] if wtype == torch.uint8 else x
-                omx = w[f"{att}output.weight_mx"] if wtype == torch.uint8 else x
-                orx = w[f"{att}output.weight_rx"] if wtype == torch.uint8 else x
-                omy = w[f"{att}output.weight_my"] if wtype == torch.uint8 else x
-                ory = w[f"{att}output.weight_ry"] if wtype == torch.uint8 else x
-                if self.version in [5.1, 5.2, 6.0]:
-                    gw = w[f"{att}gate.weight"]
-                    if dd.stream:
-                        gw = gw.to(device=dev, non_blocking=True)
-                    gmx = w[f"{att}gate.weight_mx"] if wtype == torch.uint8 else x
-                    grx = w[f"{att}gate.weight_rx"] if wtype == torch.uint8 else x
-                    gmy = w[f"{att}gate.weight_my"] if wtype == torch.uint8 else x
-                    gry = w[f"{att}gate.weight_ry"] if wtype == torch.uint8 else x
-                if self.version == 4:
-                    (
-                        x,
-                        state[i * 5 + 0],
-                        state[i * 5 + 1],
-                        state[i * 5 + 2],
-                        state[i * 5 + 3],
-                    ) = ATT(
-                        x,
-                        state[i * 5 + 0],
-                        state[i * 5 + 1],
-                        state[i * 5 + 2],
-                        state[i * 5 + 3],
-                        w[f"{bbb}ln1.weight"],
-                        w[f"{bbb}ln1.bias"],
-                        w[f"{att}time_mix_k"],
-                        w[f"{att}time_mix_v"],
-                        w[f"{att}time_mix_r"],
-                        w[f"{att}time_decay"],
-                        w[f"{att}time_first"],
-                        kw,
-                        vw,
-                        rw,
-                        ow,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                        omx,
-                        orx,
-                        omy,
-                        ory,
-                    )
-                elif self.version == 5:
-                    x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
-                        x,
-                        state[i * 3 + 0],
-                        state[i * 3 + 1],
-                        w[f"{bbb}ln1.weight"],
-                        w[f"{bbb}ln1.bias"],
-                        w[f"{att}ln_x.weight"],
-                        w[f"{att}ln_x.bias"],
-                        w[f"{att}time_mix_k"],
-                        w[f"{att}time_mix_v"],
-                        w[f"{att}time_mix_r"],
-                        w[f"{att}time_decay"],
-                        w[f"{att}time_first"],
-                        kw,
-                        vw,
-                        rw,
-                        ow,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                        omx,
-                        orx,
-                        omy,
-                        ory,
-                    )
-                elif self.version in [5.1, 5.2]:
-                    x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
-                        x,
-                        state[i * 3 + 0],
-                        state[i * 3 + 1],
-                        w[f"{bbb}ln1.weight"],
-                        w[f"{bbb}ln1.bias"],
-                        w[f"{att}ln_x.weight"],
-                        w[f"{att}ln_x.bias"],
-                        w[f"{att}time_mix_k"],
-                        w[f"{att}time_mix_v"],
-                        w[f"{att}time_mix_r"],
-                        w[f"{att}time_mix_g"],
-                        w[f"{att}time_decay"],
-                        w[f"{att}time_first"],
-                        kw,
-                        vw,
-                        rw,
-                        gw,
-                        ow,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                        gmx,
-                        grx,
-                        gmy,
-                        gry,
-                        omx,
-                        orx,
-                        omy,
-                        ory,
-                    )
-                elif self.version == 6.0:
-                    x, state[i * 3 + 0], state[i * 3 + 1] = ATT(
-                        x,
-                        state[i * 3 + 0],
-                        state[i * 3 + 1],
-                        w[f"{bbb}ln1.weight"],
-                        w[f"{bbb}ln1.bias"],
-                        w[f"{att}ln_x.weight"],
-                        w[f"{att}ln_x.bias"],
-                        w[f"{att}time_maa_x"],
-                        w[f"{att}time_maa_w"],
-                        w[f"{att}time_maa_k"],
-                        w[f"{att}time_maa_v"],
-                        w[f"{att}time_maa_r"],
-                        w[f"{att}time_maa_g"],
-                        w[f"{att}time_maa_w1"],
-                        w[f"{att}time_maa_w2"],
-                        w[f"{att}time_decay_w1"],
-                        w[f"{att}time_decay_w2"],
-                        w[f"{att}time_decay"],
-                        w[f"{att}time_first"],
-                        kw,
-                        vw,
-                        rw,
-                        gw,
-                        ow,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                        gmx,
-                        grx,
-                        gmy,
-                        gry,
-                        omx,
-                        orx,
-                        omy,
-                        ory,
-                    )
-                if dd.stream:
-                    del kw, vw, rw, ow
-                    if self.version in [5.1, 5.2, 6.0]:
-                        del gw
-
-                kw = w[f"{ffn}key.weight"]
-                vw = w[f"{ffn}value.weight"]
-                rw = w[f"{ffn}receptance.weight"]
-                if dd.stream:
-                    kw = kw.to(device=dev, non_blocking=True)
-                    vw = vw.to(device=dev, non_blocking=True)
-                    rw = rw.to(device=dev, non_blocking=True)
-                kmx = w[f"{ffn}key.weight_mx"] if wtype == torch.uint8 else x
-                krx = w[f"{ffn}key.weight_rx"] if wtype == torch.uint8 else x
-                kmy = w[f"{ffn}key.weight_my"] if wtype == torch.uint8 else x
-                kry = w[f"{ffn}key.weight_ry"] if wtype == torch.uint8 else x
-                vmx = w[f"{ffn}value.weight_mx"] if wtype == torch.uint8 else x
-                vrx = w[f"{ffn}value.weight_rx"] if wtype == torch.uint8 else x
-                vmy = w[f"{ffn}value.weight_my"] if wtype == torch.uint8 else x
-                vry = w[f"{ffn}value.weight_ry"] if wtype == torch.uint8 else x
-                rmx = w[f"{ffn}receptance.weight_mx"] if wtype == torch.uint8 else x
-                rrx = w[f"{ffn}receptance.weight_rx"] if wtype == torch.uint8 else x
-                rmy = w[f"{ffn}receptance.weight_my"] if wtype == torch.uint8 else x
-                rry = w[f"{ffn}receptance.weight_ry"] if wtype == torch.uint8 else x
-                if self.version == 4:
-                    offset = i * 5 + 4
-                elif int(self.version) in [5, 6]:
-                    offset = i * 3 + 2
-                if self.version < 6.0:
-                    x, state[offset] = FFN(
-                        x,
-                        state[offset],
-                        w[f"{bbb}ln2.weight"],
-                        w[f"{bbb}ln2.bias"],
-                        w[f"{ffn}time_mix_k"],
-                        w[f"{ffn}time_mix_r"],
-                        kw,
-                        vw,
-                        rw,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                    )
-                else:
-                    x, state[offset] = FFN(
-                        x,
-                        state[offset],
-                        w[f"{bbb}ln2.weight"],
-                        w[f"{bbb}ln2.bias"],
-                        w[f"{ffn}time_maa_k"],
-                        w[f"{ffn}time_maa_r"],
-                        kw,
-                        vw,
-                        rw,
-                        kmx,
-                        krx,
-                        kmy,
-                        kry,
-                        vmx,
-                        vrx,
-                        vmy,
-                        vry,
-                        rmx,
-                        rrx,
-                        rmy,
-                        rry,
-                    )
-                if dd.stream:
-                    del kw, vw, rw
-
-                if self.RESCALE_LAYER > 0:
-                    if (i + 1) % self.RESCALE_LAYER == 0:
-                        x = x / 2
-
-            dd = self.strategy[args.n_layer]
-            x = x[-1, :] if (seq_mode and (not full_output)) else x
-            x = x.to(dtype=dd.atype, device=dd.device)
-
-            x = F.layer_norm(
-                x, (args.n_embd,), weight=w["ln_out.weight"], bias=w["ln_out.bias"]
-            )
-            if w["head.weight"].dtype != torch.uint8:
-                x = x @ w["head.weight"]
-            else:
-                if seq_mode and full_output:
-                    x = mm8_seq(
-                        x,
-                        w["head.weight"],
-                        w["head.weight_mx"],
-                        w["head.weight_rx"],
-                        w["head.weight_my"],
-                        w["head.weight_ry"],
-                    )
-                else:
-                    x = mm8_one(
-                        x,
-                        w["head.weight"],
-                        w["head.weight_mx"],
-                        w["head.weight_rx"],
-                        w["head.weight_my"],
-                        w["head.weight_ry"],
-                    )
-
-            return x.float(), state
-
-
-if os.environ.get("RWKV_V7_ON") == "1":
-    RWKV = RWKV_x070