File size: 175 Bytes
2f9282b |
1 2 3 4 5 6 7 8 9 10 11 |
# -*- coding: utf-8 -*-
import torch
@torch.jit.script
def normalize_output(
    q: torch.Tensor,
    k: torch.Tensor,
    o: torch.Tensor,
    eps: float = 1e-10,
) -> torch.Tensor:
    """Divide ``o`` by the running normalizer ``q . cumsum(k)``.

    ``k`` is cumulatively summed along dim -2. The result is dotted with
    ``q`` over the last dim (kept as a size-1 dim so it broadcasts against
    ``o``). ``o`` is then divided by that quantity plus ``eps``.

    Args:
        q: query-like tensor; presumably shaped (..., seq_len, dim), with
            dim -2 as the position axis (TODO confirm against callers).
        k: key-like tensor, same shape as ``q``. It is not modified in place.
        o: tensor to normalize; must broadcast against a (..., seq_len, 1)
            normalizer.
        eps: added to the normalizer to avoid division by zero. The default
            1e-10 preserves the original behaviour. Note that 1e-10 rounds
            to 0 in float16, so half-precision callers may want a larger
            value.

    Returns:
        ``o / (sum(q * cumsum(k, -2), -1, keepdim=True) + eps)``.
    """
    # Prefix sum over dim -2: position i sees the sum of k[0..i].
    k_cum = k.cumsum(-2)
    z = (q * k_cum).sum(-1, keepdim=True)
    return o / (z + eps)
|