justinpinkney committed
Commit 795ad25 · Parent: c8ca82f

re-enable rotary cache

Files changed (1): modelling_RW.py (+12 -16)
modelling_RW.py CHANGED
@@ -60,6 +60,7 @@ class RotaryEmbedding(torch.nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.head_dim = head_dim
         self.seq_len_cached = None
+        self.start_idx = None
         self.batch_size_cached = None
         self.cos_cached: torch.Tensor | None = None
         self.sin_cached: torch.Tensor | None = None
@@ -71,28 +72,23 @@ class RotaryEmbedding(torch.nn.Module):
         dtype=torch.bfloat16,
         start_idx: int = 0,
     ) -> torch.Tensor:
-        # if seq_len != self.seq_len_cached:
-        self.seq_len_cached = seq_len
-        t = torch.arange(start_idx, start_idx+seq_len, device=device).type_as(self.inv_freq)
-        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1).to(device)
+        if seq_len != self.seq_len_cached and self.start_idx != start_idx:
+            self.seq_len_cached = seq_len
+            t = torch.arange(start_idx, start_idx+seq_len, device=device).type_as(self.inv_freq)
+            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
+            emb = torch.cat((freqs, freqs), dim=-1).to(device)
 
-        if dtype in [torch.float16, torch.bfloat16]:
-            emb = emb.float()
+            if dtype in [torch.float16, torch.bfloat16]:
+                emb = emb.float()
 
-        self.cos_cached = emb.cos()[None, :, :]
-        self.sin_cached = emb.sin()[None, :, :]
+            self.cos_cached = emb.cos()[None, :, :]
+            self.sin_cached = emb.sin()[None, :, :]
 
-        self.cos_cached = self.cos_cached.type(dtype)
-        self.sin_cached = self.sin_cached.type(dtype)
+            self.cos_cached = self.cos_cached.type(dtype)
+            self.sin_cached = self.sin_cached.type(dtype)
 
         return self.cos_cached, self.sin_cached
 
-    def forward(self, q, k, start_idx=0):
-        batch, seq_len, head_dim = q.shape
-        cos, sin = self.cos_sin(seq_len, q.device, q.dtype, start_idx=start_idx)
-        return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
-
 
 def _make_causal_mask(
     input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
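
For reference, a minimal, self-contained sketch of the cache path this commit re-enables. Only cos_sin is touched by the diff; the RotaryEmbeddingSketch name, the head_dim/base constructor arguments, and the inv_freq setup below are assumptions added to make the snippet runnable, not part of the commit.

# Sketch of the re-enabled rotary cos/sin cache (assumed surrounding class,
# Falcon-style; the cached-condition logic mirrors the committed cos_sin).
import torch


class RotaryEmbeddingSketch(torch.nn.Module):
    def __init__(self, head_dim: int, base: int = 10000):
        super().__init__()
        # Assumed standard rotary inverse-frequency setup; not part of this diff.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.head_dim = head_dim
        self.seq_len_cached = None
        self.start_idx = None
        self.cos_cached: torch.Tensor | None = None
        self.sin_cached: torch.Tensor | None = None

    def cos_sin(self, seq_len, device="cpu", dtype=torch.bfloat16, start_idx=0):
        # As committed: recompute only when both the cached sequence length and
        # the stored start index disagree with the request; otherwise reuse the
        # previously built cos/sin tensors.
        if seq_len != self.seq_len_cached and self.start_idx != start_idx:
            self.seq_len_cached = seq_len
            t = torch.arange(start_idx, start_idx + seq_len, device=device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(device)

            # Build the tables in float32 for half-precision dtypes, then cast back.
            if dtype in [torch.float16, torch.bfloat16]:
                emb = emb.float()

            self.cos_cached = emb.cos()[None, :, :]
            self.sin_cached = emb.sin()[None, :, :]

            self.cos_cached = self.cos_cached.type(dtype)
            self.sin_cached = self.sin_cached.type(dtype)

        return self.cos_cached, self.sin_cached


rope = RotaryEmbeddingSketch(head_dim=64)
cos1, sin1 = rope.cos_sin(seq_len=8, dtype=torch.float32)
cos2, sin2 = rope.cos_sin(seq_len=8, dtype=torch.float32)
print(cos1 is cos2, sin1 is sin2)  # True True -- the second call hits the cache

With the guard back in place, a repeated call with the same seq_len returns the cached cos/sin tensors instead of rebuilding them on every attention step, which is the behaviour the previously commented-out check was meant to provide.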