Gregniuki committed
Commit 5a638fc · verified · 1 Parent(s): 424a6e9

Update model/modules.py

Files changed (1)
  1. model/modules.py +10 -4
model/modules.py CHANGED
@@ -208,13 +208,19 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
 
 
 def get_pos_embed_indices(start, length, max_pos, scale=1.0):
-    # length = length if isinstance(length, int) else length.max()
-    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
+    """
+    Generate positional embedding indices with bfloat16 for computations.
+    """
+    # Convert scale to a tensor with bfloat16
+    scale = scale * torch.ones_like(start, dtype=torch.bfloat16)
+
+    # Compute positions using bfloat16
     pos = (
         start.unsqueeze(1)
-        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
+        + (torch.arange(length, device=start.device, dtype=torch.bfloat16).unsqueeze(0) * scale.unsqueeze(1)).long()
     )
-    # avoid extra long error.
+
+    # Ensure positions do not exceed max_pos
     pos = torch.where(pos < max_pos, pos, max_pos - 1)
     return pos
 
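
For context, below is a minimal, runnable sketch of get_pos_embed_indices as it reads after this commit, together with a small usage example. The batch size, start offsets, sequence length, and max_pos value are illustrative assumptions, not values taken from the repository.

import torch

def get_pos_embed_indices(start, length, max_pos, scale=1.0):
    # Broadcast scale (possibly a Python scalar) to a bfloat16 tensor.
    scale = scale * torch.ones_like(start, dtype=torch.bfloat16)
    # Build per-sequence position indices in bfloat16, then cast back to long.
    pos = (
        start.unsqueeze(1)
        + (torch.arange(length, device=start.device, dtype=torch.bfloat16).unsqueeze(0) * scale.unsqueeze(1)).long()
    )
    # Clamp so indices never run past the embedding table.
    pos = torch.where(pos < max_pos, pos, max_pos - 1)
    return pos

# Hypothetical usage: a batch of two sequences starting at offsets 0 and 5.
start = torch.tensor([0, 5])
pos = get_pos_embed_indices(start, length=8, max_pos=4096)
print(pos.shape)  # torch.Size([2, 8])
print(pos[1])     # tensor([ 5,  6,  7,  8,  9, 10, 11, 12])

One caveat worth verifying with this dtype choice: bfloat16 carries only 8 significand bits, so integers above 256 are not all exactly representable, and a bfloat16 torch.arange over long sequences may therefore yield inexact indices where float32 did not.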