Commit
·
795ad25
1
Parent(s):
c8ca82f
re-enable rotary cache
Browse files- modelling_RW.py +12 -16
modelling_RW.py
CHANGED
@@ -60,6 +60,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
60 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
61 |
self.head_dim = head_dim
|
62 |
self.seq_len_cached = None
|
|
|
63 |
self.batch_size_cached = None
|
64 |
self.cos_cached: torch.Tensor | None = None
|
65 |
self.sin_cached: torch.Tensor | None = None
|
@@ -71,28 +72,23 @@ class RotaryEmbedding(torch.nn.Module):
|
|
71 |
dtype=torch.bfloat16,
|
72 |
start_idx: int = 0,
|
73 |
) -> torch.Tensor:
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
|
80 |
-
|
81 |
-
|
82 |
|
83 |
-
|
84 |
-
|
85 |
|
86 |
-
|
87 |
-
|
88 |
|
89 |
return self.cos_cached, self.sin_cached
|
90 |
|
91 |
-
def forward(self, q, k, start_idx=0):
|
92 |
-
batch, seq_len, head_dim = q.shape
|
93 |
-
cos, sin = self.cos_sin(seq_len, q.device, q.dtype, start_idx=start_idx)
|
94 |
-
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
95 |
-
|
96 |
|
97 |
def _make_causal_mask(
|
98 |
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|
|
|
60 |
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
61 |
self.head_dim = head_dim
|
62 |
self.seq_len_cached = None
|
63 |
+
self.start_idx = None
|
64 |
self.batch_size_cached = None
|
65 |
self.cos_cached: torch.Tensor | None = None
|
66 |
self.sin_cached: torch.Tensor | None = None
|
|
|
72 |
dtype=torch.bfloat16,
|
73 |
start_idx: int = 0,
|
74 |
) -> torch.Tensor:
|
75 |
+
if seq_len != self.seq_len_cached and self.start_idx != start_idx:
|
76 |
+
self.seq_len_cached = seq_len
|
77 |
+
t = torch.arange(start_idx, start_idx+seq_len, device=device).type_as(self.inv_freq)
|
78 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
79 |
+
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
80 |
|
81 |
+
if dtype in [torch.float16, torch.bfloat16]:
|
82 |
+
emb = emb.float()
|
83 |
|
84 |
+
self.cos_cached = emb.cos()[None, :, :]
|
85 |
+
self.sin_cached = emb.sin()[None, :, :]
|
86 |
|
87 |
+
self.cos_cached = self.cos_cached.type(dtype)
|
88 |
+
self.sin_cached = self.sin_cached.type(dtype)
|
89 |
|
90 |
return self.cos_cached, self.sin_cached
|
91 |
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
def _make_causal_mask(
|
94 |
input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
|