Gregniuki commited on
Commit
b8383b8
·
verified ·
1 Parent(s): a591b69

Update model/modules.py

Browse files
Files changed (1) hide show
  1. model/modules.py +2 -2
model/modules.py CHANGED
@@ -212,12 +212,12 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
212
  Generate positional embedding indices with bfloat16 for computations.
213
  """
214
  # Convert scale to a tensor with bfloat16
215
- scale = scale * torch.ones_like(start, dtype=torch.bfloat16)
216
 
217
  # Compute positions using bfloat16
218
  pos = (
219
  start.unsqueeze(1)
220
- + (torch.arange(length, device=start.device, dtype=torch.bfloat16).unsqueeze(0) * scale.unsqueeze(1)).long()
221
  )
222
 
223
  # Ensure positions do not exceed max_pos
 
212
  Generate positional embedding indices with bfloat16 for computations.
213
  """
214
  # Convert scale to a tensor with bfloat16
215
+ scale = scale * torch.ones_like(start, dtype=torch.float32)
216
 
217
  # Compute positions using bfloat16
218
  pos = (
219
  start.unsqueeze(1)
220
+ + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
221
  )
222
 
223
  # Ensure positions do not exceed max_pos