Spaces:
Running
on
Zero
Running
on
Zero
Update model/modules.py
Browse files- model/modules.py +2 -2
model/modules.py
CHANGED
@@ -212,12 +212,12 @@ def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
|
212 |
Generate positional embedding indices with bfloat16 for computations.
|
213 |
"""
|
214 |
# Convert scale to a tensor with bfloat16
|
215 |
-
scale = scale * torch.ones_like(start, dtype=torch.
|
216 |
|
217 |
# Compute positions using bfloat16
|
218 |
pos = (
|
219 |
start.unsqueeze(1)
|
220 |
-
+ (torch.arange(length, device=start.device, dtype=torch.
|
221 |
)
|
222 |
|
223 |
# Ensure positions do not exceed max_pos
|
|
|
212 |
Generate positional embedding indices with bfloat16 for computations.
|
213 |
"""
|
214 |
# Convert scale to a tensor with bfloat16
|
215 |
+
scale = scale * torch.ones_like(start, dtype=torch.float32)
|
216 |
|
217 |
# Compute positions using bfloat16
|
218 |
pos = (
|
219 |
start.unsqueeze(1)
|
220 |
+
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
221 |
)
|
222 |
|
223 |
# Ensure positions do not exceed max_pos
|