File size: 175 Bytes
2f9282b
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
# -*- coding: utf-8 -*-

import torch


@torch.jit.script
def normalize_output(q, k, o):
    """Scale ``o`` by a normalizer built from ``q`` and a running sum of ``k``.

    ``k`` is accumulated along its second-to-last dimension, multiplied
    element-wise with ``q``, and reduced over the last dimension (kept as
    size 1). ``o`` is then divided by that result plus ``1e-10``.

    NOTE(review): presumably dim -2 is the sequence axis and dim -1 the
    feature axis of a linear-attention layout -- confirm against callers.
    """
    # Inclusive prefix sum of k along dim -2.
    k_running = torch.cumsum(k, dim=-2)
    # Reduce over the last dim; keepdim leaves a size-1 axis so the result
    # broadcasts against o in the division below.
    denom = torch.sum(q * k_running, dim=-1, keepdim=True)
    # Small constant added to the divisor, presumably to avoid dividing by
    # exactly zero.
    return o / (denom + 1e-10)