| import torch | |
| from jaxtyping import Int | |
| from torch import Tensor | |
def add_third_context_index(
    indices: Int[Tensor, "*batch 2"],
) -> Int[Tensor, "*batch 3"]:
    """Insert a midpoint index between each (left, right) index pair.

    The last axis of ``indices`` is split into a left and a right value. The
    result stacks ``(left, (left + right) // 2, right)`` along a new last axis,
    so all leading dimensions are preserved.

    Args:
        indices: Tensor whose last axis holds exactly two values (left, right).
            Unpacking raises ``ValueError`` if that axis does not have size 2.

    Returns:
        Tensor with the same leading dimensions as ``indices`` and a last axis
        of size 3: left, floor-divided midpoint, right.
    """
    start, end = torch.unbind(indices, dim=-1)
    # ``//`` floor-divides, so an odd sum rounds the midpoint down.
    midpoint = (start + end) // 2
    return torch.stack([start, midpoint, end], dim=-1)