File size: 19,553 Bytes
c1a12af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 |
import torch
from einops import rearrange
from .hybrid_cache import TimeMixState, BlockState
import math
import torch.nn as nn
from torch.nn import functional as F
from .configuration_rwkv_hybrid import RwkvHybridConfig
try:
import triton
from rwkvfla.ops.rwkv7 import (
fused_recurrent_rwkv7,
chunk_rwkv7,
native_recurrent_rwkv7,
) # pylint: disable=C0411
from rwkvfla.ops.rwkv6 import (
fused_recurrent_rwkv6,
chunk_rwkv6,
native_recurrent_rwkv6,
)
except ImportError:
from rwkvfla.ops.rwkv7 import native_recurrent_rwkv7 # pylint: disable=C0411
from rwkvfla.ops.rwkv6 import native_recurrent_rwkv6
fused_recurrent_rwkv7 = native_recurrent_rwkv7
chunk_rwkv7 = native_recurrent_rwkv7
chunk_rwkv6 = native_recurrent_rwkv6
fused_recurrent_rwkv6 = native_recurrent_rwkv6
class Rwkv_Tmix_x070(nn.Module):
def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
super().__init__()
self.args = args
self.layer_id = layer_id
self.hidden_size = args.hidden_size
self.update_v_first = update_v_first
self.get_v_first = get_v_first
self.head_size = args.head_size
self.n_head = args.num_wkv_heads
assert args.hidden_size % self.n_head == 0
H = self.n_head
N = self.head_size
self.x_r = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_w = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_v = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.x_g = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
D_DECAY_LORA = 64
D_AAA_LORA = 64
D_MV_LORA = 32
D_GATE_LORA = 128
self.w1 = nn.Parameter(torch.Tensor(args.hidden_size, D_DECAY_LORA))
self.w2 = nn.Parameter(torch.Tensor(D_DECAY_LORA, args.hidden_size))
self.w0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.a1 = nn.Parameter(torch.Tensor(args.hidden_size, D_AAA_LORA))
self.a2 = nn.Parameter(torch.Tensor(D_AAA_LORA, args.hidden_size))
self.a0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.v1 = nn.Parameter(torch.Tensor(args.hidden_size, D_MV_LORA))
self.v2 = nn.Parameter(torch.Tensor(D_MV_LORA, args.hidden_size))
self.v0 = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
if self.args.wkv_has_gate:
self.g1 = nn.Parameter(torch.Tensor(args.hidden_size, D_GATE_LORA))
self.g2 = nn.Parameter(torch.Tensor(D_GATE_LORA, args.hidden_size))
self.k_k = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.k_a = nn.Parameter(torch.Tensor(1, 1, args.hidden_size))
self.r_k = nn.Parameter(torch.Tensor(H, N))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.key = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.value = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
self.output = nn.Linear(args.hidden_size, args.hidden_size, bias=False)
if self.args.wkv_has_group_norm:
self.ln_x = nn.GroupNorm(
H, args.hidden_size, eps=(1e-5) * (args.head_size_divisor**2)
)
def post_init(self):
with torch.no_grad():
ratio_0_to_1 = self.layer_id / (self.args.num_hidden_layers - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (
self.layer_id / self.args.num_hidden_layers
) # 1 to ~0
ddd = torch.ones(1, 1, self.args.hidden_size)
for i in range(self.args.hidden_size):
ddd[0, 0, i] = i / self.args.hidden_size
nn.init.constant_(self.x_r, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
nn.init.constant_(self.x_w, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
nn.init.constant_(
self.x_k,
1.0 - (torch.pow(ddd, 0.9 * ratio_1_to_almost0) + 0.4 * ratio_0_to_1),
)
nn.init.constant_(
self.x_v,
1.0 - (torch.pow(ddd, 0.4 * ratio_1_to_almost0) + 0.6 * ratio_0_to_1),
)
nn.init.constant_(self.x_a, 1.0 - torch.pow(ddd, 0.9 * ratio_1_to_almost0))
nn.init.constant_(self.x_g, 1.0 - torch.pow(ddd, 0.2 * ratio_1_to_almost0))
def ortho_init(x, scale):
shape = x.shape
original_dtype = x.dtype
x_fp32 = x.float()
if len(shape) == 2:
gain = math.sqrt(shape[0] / shape[1]) if shape[0] > shape[1] else 1
nn.init.orthogonal_(x_fp32, gain=gain * scale)
elif len(shape) == 3:
gain = math.sqrt(shape[1] / shape[2]) if shape[1] > shape[2] else 1
for i in range(shape[0]):
nn.init.orthogonal_(x_fp32[i], gain=gain * scale)
else:
raise ValueError("ortho_init only supports 2D or 3D tensors")
x.data.copy_(x_fp32.to(original_dtype))
return x
D_DECAY_LORA = 64
nn.init.zeros_(self.w1)
self.w2 = nn.Parameter(
ortho_init(torch.zeros(D_DECAY_LORA, self.args.hidden_size), 0.1)
)
decay_speed = torch.ones(self.args.hidden_size)
for n in range(self.args.hidden_size):
decay_speed[n] = -7 + 5 * (n / (self.args.hidden_size - 1)) ** (
0.85 + 1.0 * ratio_0_to_1**0.5
)
nn.init.constant_(
self.w0, decay_speed.reshape(1, 1, self.args.hidden_size) + 0.5
)
D_AAA_LORA = 64
nn.init.zeros_(self.a1)
self.a2 = nn.Parameter(
ortho_init(torch.zeros(D_AAA_LORA, self.args.hidden_size), 0.1)
)
nn.init.zeros_(self.a0)
D_MV_LORA = 32
nn.init.zeros_(self.v1)
self.v2 = nn.Parameter(
ortho_init(torch.zeros(D_MV_LORA, self.args.hidden_size), 0.1)
)
nn.init.constant_(self.v0, 1.0)
D_GATE_LORA = 128
if self.args.wkv_has_gate:
nn.init.zeros_(self.g1)
self.g2 = nn.Parameter(
ortho_init(torch.zeros(D_GATE_LORA, self.args.hidden_size), 0.1)
)
nn.init.constant_(self.k_k, 0.85)
nn.init.constant_(self.k_a, 1.0)
nn.init.zeros_(self.r_k)
nn.init.zeros_(self.receptance.weight)
nn.init.zeros_(self.key.weight)
nn.init.zeros_(self.value.weight)
nn.init.zeros_(self.output.weight)
if self.args.wkv_has_group_norm:
nn.init.ones_(self.ln_x.weight)
nn.init.zeros_(self.ln_x.bias)
def apply_wkv7_state(self, r, k, v, w, a, b, s):
r = rearrange(r, "b l (h d) -> b h l d", h=self.n_head)
k = rearrange(k, "b l (h d) -> b h l d", h=self.n_head)
v = rearrange(v, "b l (h d) -> b h l d", h=self.n_head)
w = rearrange(w, "b l (h d) -> b h l d", h=self.n_head)
a = rearrange(a, "b l (h d) -> b h l d", h=self.n_head)
b = rearrange(b, "b l (h d) -> b h l d", h=self.n_head)
if r.device.type == "cpu":
o, state = native_recurrent_rwkv7(
r,
k,
v,
w,
a,
b,
scale=1.0,
initial_state=s.transpose(-1, -2),
output_final_state=True,
use_log_w=False,
head_first=True,
)
state = state.transpose(-1, -2)
elif self.training:
o, state = chunk_rwkv7(
r,
k,
v,
w,
a,
b,
scale=1.0,
initial_state=s,
output_final_state=True,
use_log_w=False,
head_first=True,
)
else:
o, state = fused_recurrent_rwkv7(
r,
k,
v,
w,
a,
b,
scale=1.0,
initial_state=s,
output_final_state=True,
use_log_w=False,
head_first=True,
)
x = rearrange(o, "b h l d -> b l (h d)")
return x, state
def forward(self, x, last_state: TimeMixState):
shift_state = last_state.shift_state
B, T, C = x.size()
H = self.n_head
if shift_state is not None:
xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
else:
xx = self.time_shift(x) - x
lx = x[:, -1]
xr = x + xx * self.x_r
xw = x + xx * self.x_w
xk = x + xx * self.x_k
xv = x + xx * self.x_v
xa = x + xx * self.x_a
xg = x + xx * self.x_g
r = self.receptance(xr)
w = (
-F.softplus(-(self.w0 + torch.tanh(xw @ self.w1) @ self.w2)) - 0.5
) # soft-clamp to (-inf, -0.5)
k = self.key(xk)
v = self.value(xv)
if self.layer_id == 0:
self.update_v_first(v)
else:
# Original implementation
v = v + (self.get_v_first().to(v.device) - v) * torch.sigmoid(
self.v0 + (xv @ self.v1) @ self.v2
) # add value residual
a = torch.sigmoid(
self.a0 + (xa @ self.a1) @ self.a2
) # a is "in-context learning rate"
if self.args.wkv_has_gate:
g = torch.sigmoid(xg @ self.g1) @ self.g2
kk = k * self.k_k
kk = F.normalize(kk.view(B, T, H, -1), dim=-1, p=2.0).view(B, T, C)
k = k * (1 + (a - 1) * self.k_a)
wkv_state = last_state.wkv_state
x, wkv_state = self.apply_wkv7_state(
r,
k,
v,
w,
-kk,
(kk * a),
s=wkv_state,
)
if self.args.wkv_has_group_norm:
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = x + (
(r.view(B, T, H, -1) * k.view(B, T, H, -1) * self.r_k).sum(
dim=-1, keepdim=True
)
* v.view(B, T, H, -1)
).view(B, T, C)
x = self.output(x * g) if self.args.wkv_has_gate else self.output(x)
return x, TimeMixState(lx, wkv_state)
class Rwkv7Attention(nn.Module):
def __init__(self, args: RwkvHybridConfig, layer_id, update_v_first, get_v_first):
super().__init__()
self.args = args
self.layer_idx = layer_id
self.time_mixer = Rwkv_Tmix_x070(args, layer_id, update_v_first, get_v_first)
def forward(self, hidden_states, past_key_value, **kwargs):
attn_output = hidden_states
batch_size, token_length, _ = attn_output.size()
if past_key_value is not None and len(past_key_value) > self.layer_idx:
last_state = past_key_value[self.layer_idx][0]
else:
last_state = self.init_state(
batch_size, attn_output.device, attn_output.dtype
)
attn_output, states = self.time_mixer(attn_output, last_state.time_mix_state)
last_state.time_mix_state = states
if past_key_value is not None:
past_key_value.update(token_length, last_state, self.layer_idx)
return attn_output, None
def init_state(self, batch_size, device, dtype) -> BlockState:
wkv_states = torch.zeros(
(
batch_size,
self.args.num_wkv_heads,
self.args.head_size,
self.args.head_size,
),
device=device,
dtype=torch.float32,
)
token_shift = torch.zeros(
(batch_size, self.args.hidden_size), device=device, dtype=dtype
)
return BlockState(TimeMixState(token_shift, wkv_states), None)
class Rwkv_Tmix_x060(nn.Module):
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
super().__init__()
self.args = args
self.layer_id = layer_id
self.hidden_size = args.hidden_size
self.head_size = args.head_size
self.n_head = args.num_wkv_heads
assert args.hidden_size % self.n_head == 0
H = self.n_head
N = self.head_size
with torch.no_grad():
ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
ddd = torch.ones(1, 1, args.hidden_size)
for i in range(args.hidden_size):
ddd[0, 0, i] = i / args.hidden_size
# fancy time_mix
self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
self.time_maa_v = nn.Parameter(
1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
)
self.time_maa_r = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
self.time_maa_g = nn.Parameter(
1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0)
)
D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
if args.hidden_size == 4096:
D_MIX_LORA = D_MIX_LORA * 2
self.time_maa_w1 = nn.Parameter(
torch.zeros(args.hidden_size, D_MIX_LORA * 5)
)
self.time_maa_w2 = nn.Parameter(
torch.zeros(5, D_MIX_LORA, args.hidden_size).uniform_(-0.01, 0.01)
)
# fancy time_decay
decay_speed = torch.ones(args.head_size)
for n in range(args.head_size):
decay_speed[n] = -6 + 5 * (n / (args.head_size - 1)) ** (
0.7 + 1.3 * ratio_0_to_1
)
self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.head_size))
D_DECAY_LORA = 64
if args.hidden_size == 4096:
D_DECAY_LORA = D_DECAY_LORA * 2
self.time_decay_w1 = nn.Parameter(
torch.zeros(args.hidden_size, D_DECAY_LORA)
)
self.time_decay_w2 = nn.Parameter(
torch.zeros(D_DECAY_LORA, args.head_size).uniform_(-0.01, 0.01)
)
tmp = torch.zeros(args.head_size)
for n in range(args.head_size):
zigzag = ((n + 1) % 3 - 1) * 0.1
tmp[n] = ratio_0_to_1 * (1 - (n / (args.head_size - 1))) + zigzag
self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
# self.time_state = nn.Parameter(torch.zeros(self.n_head, self.head_size, self.head_size))
self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
self.receptance = nn.Linear(args.hidden_size, args.head_size, bias=False)
self.key = nn.Linear(args.hidden_size, args.head_size, bias=False)
self.value = nn.Linear(args.hidden_size, args.head_size, bias=False)
self.output = nn.Linear(args.head_size, args.hidden_size, bias=False)
self.gate = nn.Linear(args.hidden_size, args.head_size, bias=False)
if self.args.wkv_has_group_norm:
self.ln_x = nn.GroupNorm(
self.n_head, args.head_size, eps=(1e-5) * (args.head_size_divisor**2)
)
def post_init(self):
pass
def forward(self, x, last_state: TimeMixState):
shift_state = last_state.shift_state
B, T, C = x.size()
H = self.n_head
if shift_state is not None:
xx = torch.concat((shift_state.unsqueeze(1), x[:, :-1]), dim=1) - x
else:
xx = self.time_shift(x) - x
lx = x[:, -1]
xxx = x + xx * self.time_maa_x
xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
mw, mk, mv, mr, mg = xxx.unbind(dim=0)
xw = x + xx * (self.time_maa_w + mw)
xk = x + xx * (self.time_maa_k + mk)
xv = x + xx * (self.time_maa_v + mv)
xr = x + xx * (self.time_maa_r + mr)
xg = x + xx * (self.time_maa_g + mg)
r = self.receptance(xr)
k = self.key(xk)
v = self.value(xv)
g = F.silu(self.gate(xg))
ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
w = self.time_decay + ww
wkv_state = last_state.wkv_state
x, wkv_state = self.apply_wkv6_state(
B, T, C, H, r, k, v, w, u=self.time_faaaa, s=wkv_state
)
if self.args.wkv_has_group_norm:
x = self.ln_x(x.view(B * T, C)).view(B, T, C)
x = self.output(x * g)
return x, TimeMixState(lx, wkv_state)
def apply_wkv6_state(self, B, T, C, H, r, k, v, w, u, s):
r = rearrange(r, "b l (h d) -> b h l d", h=H)
k = rearrange(k, "b l (h d) -> b h l d", h=H)
v = rearrange(v, "b l (h d) -> b h l d", h=H)
w = rearrange(w, "b l (h d) -> b h l d", h=H)
if r.device.type == "cpu":
wkv6_func = native_recurrent_rwkv6
elif self.training:
wkv6_func = chunk_rwkv6
else:
wkv6_func = fused_recurrent_rwkv6
o, state = wkv6_func(
r,
k,
v,
-torch.exp(w),
u=u,
scale=1.0,
initial_state=s,
output_final_state=True,
)
x = rearrange(o, "b h l d -> b l (h d)")
return x, state
class Rwkv6Attention(nn.Module):
def __init__(self, args: RwkvHybridConfig, layer_id, **kwargs):
super().__init__()
self.args = args
self.layer_idx = layer_id
self.time_mixer = Rwkv_Tmix_x060(args, layer_id, **kwargs)
def forward(self, hidden_states, past_key_value, **kwargs):
attn_output = hidden_states
B, T, C = attn_output.size()
if past_key_value is not None:
if len(past_key_value) <= self.layer_idx:
last_state = None
else:
last_state = past_key_value[self.layer_idx][0]
if last_state is None:
wkv_states = torch.zeros(
(B, self.args.num_wkv_heads, self.args.head_size, self.args.head_size),
device=attn_output.device,
dtype=torch.float32,
)
token_shift = torch.zeros(
(B, C), device=attn_output.device, dtype=attn_output.dtype
)
time_state = TimeMixState(token_shift, wkv_states)
channel_state = None
last_state = BlockState(time_state, channel_state)
attn_output, states = self.time_mixer(attn_output, last_state.time_mix_state)
last_state.time_mix_state = states
if past_key_value is not None:
past_key_value.update(T, last_state, self.layer_idx)
return attn_output, None
|