Commit df388cc · Parent(s): c572a14

Corrected rotary embedding

attention.py  CHANGED  (+36 -16)
@@ -28,7 +28,7 @@ class RotaryEmbedding(nn.Module):
         d_rotary: int,
         rotary_base: float = 10000.0,
         initial_cos_sin_cache_len: int = 2048,
-        device: torch.device =
+        device: torch.device | None = None,
     ) -> None:
         super().__init__()
         self.d_rotary = d_rotary
@@ -37,31 +37,37 @@ class RotaryEmbedding(nn.Module):
         self.dtype = torch.float32
         self._update_cos_sin_cache(seqlen=initial_cos_sin_cache_len)

-    def _update_cos_sin_cache(
+    def _update_cos_sin_cache(
+        self,
+        seqlen: int,
+        device: str | None = None,
+        dtype: torch.dtype | None = None,
+    ) -> None:
         # only call this function when seqlen is larger than _max_seqlen
         self._max_seqlen = seqlen

         # m * theta_i = m * base^(-2i/d) = m * (1 / base^(2i/d)), where i in [1, d/2]
         m = torch.arange(
             seqlen,
-            device=
-            dtype=
+            device=device,
+            dtype=torch.float32,
         )
         theta_i = 1.0 / (
             self.rotary_base ** (
                 torch.arange(
                     start=0,
                     end=self.d_rotary,
-
-
+                    step=2,
+                    device=device,
+                    dtype=torch.float32,
                 ) / self.d_rotary
             )
         )
         # torch.outer, since torch.einsum converts from fp32 to fp16 if used with torch.amp
         # TODO: does this matter if I'm disabling torch.autocast?
         m_theta_i = torch.outer(m, theta_i)
-        self._cos_cached = torch.cos(m_theta_i).to(
-        self._sin_cached = torch.sin(m_theta_i).to(
+        self._cos_cached = torch.cos(m_theta_i).to(dtype)
+        self._sin_cached = torch.sin(m_theta_i).to(dtype)

         # TODO: scale_base caching is labelled as not yet done in Phi2
         """
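For reference, a minimal standalone sketch of the table this cache now builds: with step=2 the inner arange yields d_rotary/2 frequencies theta_i = base^(-2i/d), and the outer product with the positions m gives cos/sin tables of shape (seqlen, d_rotary // 2). The helper name and shapes below are illustrative, not taken from attention.py.

import torch

def rope_cos_sin(seqlen: int, d_rotary: int, base: float = 10000.0):
    # hypothetical helper: positions m = 0..seqlen-1, frequencies theta_i = base^(-2i/d)
    m = torch.arange(seqlen, dtype=torch.float32)
    theta_i = 1.0 / (base ** (torch.arange(0, d_rotary, 2, dtype=torch.float32) / d_rotary))
    m_theta_i = torch.outer(m, theta_i)  # (seqlen, d_rotary // 2)
    return torch.cos(m_theta_i), torch.sin(m_theta_i)

cos, sin = rope_cos_sin(seqlen=2048, d_rotary=32)
print(cos.shape)  # torch.Size([2048, 16])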
@@ -90,14 +96,17 @@ class RotaryEmbedding(nn.Module):
         sin: torch.FloatTensor, # dim: (_max_seqlen, d_rotary)
     ) -> torch.FloatTensor:
         seqlen = x.shape[1]
-
+        x_to_rotate = x[..., :self.d_rotary]
+        x_to_keep_unrotated = x[..., self.d_rotary:]
+        x1, x2 = x_to_rotate.chunk(2, dim=-1) # dim: (batch_size, seqlen, Optional[n_qkv], n_heads, d_rotary/2)
         broadcast_rearrange = "s d -> s 1 d" if x1.ndim == 4 else "s d -> s 1 1 d"
         c, s = rearrange(cos[:seqlen], broadcast_rearrange), rearrange(sin[:seqlen], broadcast_rearrange)
         x1, x2, c, s = [t.to(dtype=torch.float32) for t in [x1, x2, c, s]] # make sure rotary embedding is in float32
-
+        x_rotated = cast(
             torch.FloatTensor,
             torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1).to(x.dtype)
         )
+        return torch.cat([x_rotated, x_to_keep_unrotated], axis=-1)

     def forward(
         self,
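The rotation applied here is the usual RoPE pairing, restricted to the first d_rotary channels of each head: split them into halves (x1, x2), form (x1*cos - x2*sin, x1*sin + x2*cos), and pass the remaining channels through unchanged. A small self-contained sketch under those assumptions (apply_partial_rope is a hypothetical name, and broadcasting is done with view instead of einops.rearrange):

import torch

def apply_partial_rope(x, cos, sin, d_rotary):
    # x: (batch, seqlen, n_heads, d_head); cos/sin: (>= seqlen, d_rotary // 2)
    x_rot, x_pass = x[..., :d_rotary], x[..., d_rotary:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    c = cos[: x.shape[1]].view(1, x.shape[1], 1, -1)
    s = sin[: x.shape[1]].view(1, x.shape[1], 1, -1)
    rotated = torch.cat([x1 * c - x2 * s, x1 * s + x2 * c], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

seqlen, d_rotary = 8, 32
theta = 1.0 / (10000.0 ** (torch.arange(0, d_rotary, 2, dtype=torch.float32) / d_rotary))
angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), theta)
out = apply_partial_rope(torch.randn(2, seqlen, 4, 64), torch.cos(angles), torch.sin(angles), d_rotary)
print(out.shape)  # torch.Size([2, 8, 4, 64])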
@@ -107,9 +116,11 @@ class RotaryEmbedding(nn.Module):
         if (
             not self._max_seqlen
             or self._max_seqlen < x.shape[1] + seqlen_offset
+            or self._cos_cached.device != x.device
+            or self._cos_cached.dtype != x.dtype
             or (self.training and self._cos_cached.is_inference())
         ):
-            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset)
+            self._update_cos_sin_cache(seqlen=x.shape[1] + seqlen_offset, device=x.device, dtype=x.dtype)
         return self._apply_rotary_emb_qkv(
             x,
             cast(torch.FloatTensor, self._cos_cached[seqlen_offset:]),
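The new device/dtype conditions simply force a cache rebuild when the module is moved or its dtype changes after construction (and, as before, when the cache is too short or was built under torch.inference_mode). A hedged standalone restatement of the check, with hypothetical names:

import torch

def cache_is_stale(cos_cached, x, seqlen_needed, max_seqlen, training):
    # mirrors the rebuild condition in forward()
    return (
        not max_seqlen
        or max_seqlen < seqlen_needed
        or cos_cached.device != x.device  # module was moved after the cache was built
        or cos_cached.dtype != x.dtype    # dtype changed, e.g. .half() after init
        or (training and cos_cached.is_inference())  # cache built under torch.inference_mode()
    )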
@@ -269,7 +280,8 @@ class MHA(nn.Module):
             else RotaryEmbedding
         )
         self.rotary_emb = rotary_cls(
-            d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            # d_rotary=math.ceil((d_embedding // n_attn_heads) / 2), # d_rotary is half of d_head
+            d_rotary=32, # TODO: figure out why Phi2 uses this
             initial_cos_sin_cache_len=initial_cos_sin_cache_len,
         )

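On the TODO: Phi-2 uses partial rotary embeddings, and its released config (hidden size 2560, 32 attention heads, so head_dim 80, with a partial rotary factor of 0.4) works out to a rotary dimension of 32; the commented-out half-of-d_head formula would give 40 instead. A quick arithmetic check, assuming those config values:

d_embedding, n_attn_heads, partial_rotary_factor = 2560, 32, 0.4  # Phi-2 config values (assumed)
d_head = d_embedding // n_attn_heads            # 80
d_rotary = int(d_head * partial_rotary_factor)  # 32, matching the hard-coded value
half_d_head = d_head // 2                       # 40, what the commented-out formula would give
print(d_head, d_rotary, half_d_head)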
@@ -378,12 +390,20 @@ class MHA(nn.Module):
         kv_cache: KVCache,
         key_padding_mask: torch.BoolTensor | None,
     ) -> torch.FloatTensor:
-
-
-
+        qk = qkv[:, :, :2, :, :]
+        qk = self.rotary_emb(
+            qk,
             seqlen_offset = 0 if kv_cache is None else kv_cache.seqlen_offset,
         )
-
+        v = cast(torch.FloatTensor, qkv[:, :, 2, :, :])
+        q = qk[:, :, 0, :, :]
+        kv = torch.cat(
+            [
+                qk[:, :, 1, :, :].unsqueeze(2),
+                v.unsqueeze(2),
+            ],
+            dim=2,
+        )
         self._update_kv_cache(kv, kv_cache, self.block_n)
         causal = False # turning off causal mask for cross attention

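For the cross-attention path, the packed qkv tensor is assumed to have shape (batch, seqlen, 3, n_heads, d_head); q and k are sliced out for the rotary embedding, and k/v are re-packed for the KV cache. A shape-only sketch under that layout assumption:

import torch

batch, seqlen, n_heads, d_head = 2, 8, 4, 64
qkv = torch.randn(batch, seqlen, 3, n_heads, d_head)  # assumed packed layout

qk = qkv[:, :, :2, :, :]   # q and k, the only tensors that get rotary embeddings
v = qkv[:, :, 2, :, :]
q = qk[:, :, 0, :, :]
kv = torch.cat([qk[:, :, 1, :, :].unsqueeze(2), v.unsqueeze(2)], dim=2)
print(q.shape, kv.shape)   # torch.Size([2, 8, 4, 64]) torch.Size([2, 8, 2, 4, 64])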