SpiridonSunRotator commited on
Commit
239b4f8
·
1 Parent(s): 89d7ce1

Update modelling_RW.py

Browse files

Fixed duplicate forward

Files changed (1) hide show
  1. modelling_RW.py +0 -5
modelling_RW.py CHANGED
@@ -92,11 +92,6 @@ class RotaryEmbedding(torch.nn.Module):
92
  cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
  return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
- def forward(self, q, k):
96
- batch, seq_len, head_dim = q.shape
97
- cos, sin = self.cos_sin(seq_len, q.device)
98
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
99
-
100
 
101
  def _make_causal_mask(
102
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
 
92
  cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
  return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
 
 
 
 
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int