File size: 548 Bytes
17ed7d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def alignment_diagonal_score(alignments, binary=False):
    """Compute how diagonal alignment predictions are.

    Useful for measuring the alignment consistency of a model. For every
    encoder step, the peak attention weight across all decoder steps is
    taken; the score is the mean of those peaks over encoder steps and then
    over the batch.

    Args:
        alignments (torch.Tensor): batch of alignments. Integer masks are
            accepted; their peaks are promoted to float32 before averaging.
        binary (bool): if True, ignore scores and consider attention as a
            binary mask: every positive peak counts as 1 (non-positive peaks
            are left unchanged).

    Returns:
        float: the averaged peak attention value.

    Shape:
        - alignments: :math:`[B, T_de, T_en]`
    """
    # The result is reduced to a Python float, so autograd does not need to
    # track these ops (including the in-place write in the binary branch).
    maxs = alignments.detach().max(dim=1)[0]  # [B, T_en]: peak over decoder steps
    # Tensor.mean only supports floating-point dtypes; promote integer masks.
    if not maxs.is_floating_point():
        maxs = maxs.float()
    if binary:
        maxs[maxs > 0] = 1
    # Average over encoder steps per sample, then over the batch.
    return maxs.mean(dim=1).mean(dim=0).item()