Update model/modules.py
model/modules.py  (+10 -4)
@@ -208,13 +208,19 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_resca
 
 
 def get_pos_embed_indices(start, length, max_pos, scale=1.0):
-    # length = length if isinstance(length, int) else length.max()
-    scale = scale * torch.ones_like(start, dtype=torch.float32)  # in case scale is a scalar
+    """
+    Generate positional embedding indices with bfloat16 for computations.
+    """
+    # Convert scale to a tensor with bfloat16
+    scale = scale * torch.ones_like(start, dtype=torch.bfloat16)
+
+    # Compute positions using bfloat16
     pos = (
         start.unsqueeze(1)
-        + (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
+        + (torch.arange(length, device=start.device, dtype=torch.bfloat16).unsqueeze(0) * scale.unsqueeze(1)).long()
     )
-    # avoid extra long error.
+
+    # Ensure positions do not exceed max_pos
     pos = torch.where(pos < max_pos, pos, max_pos - 1)
     return pos
 
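For reference, a minimal usage sketch of the patched function. The input tensors, the call site, and the expected output are illustrative assumptions, not part of the commit:

import torch
from model.modules import get_pos_embed_indices

# Hypothetical example: a batch of two sequences starting at offsets 0 and 5,
# each needing 8 position indices, clamped to an embedding table of size 10.
start = torch.tensor([0, 5])
pos = get_pos_embed_indices(start, length=8, max_pos=10)

print(pos.shape)  # torch.Size([2, 8])
print(pos[1])     # tensor([5, 6, 7, 8, 9, 9, 9, 9]) -- clamped to max_pos - 1

One property of the bfloat16 choice worth noting: bfloat16 has an 8-bit mantissa, so integers above 256 are no longer exactly representable, and for long sequences the .long() cast can yield duplicated or skipped indices; the previous float32 version stayed exact up to 2^24.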