File size: 1,023 Bytes
69524d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import matplotlib.pyplot as plt
import numpy as np

def plot_mean_att_distance(mean_att_dist, *, ylim=(0, 128)):
    """Scatter-plot mean attention distance per layer, one series per head.

    Args:
        mean_att_dist: Array-like of shape ``(num_layers, num_heads)``; entry
            ``[layer, head]`` is drawn at x=``layer`` in the series for
            ``head``. Lists are accepted and converted with ``np.asarray``.
        ylim: ``(bottom, top)`` limits for the y axis (keyword-only).
            Defaults to ``(0, 128)``; points outside this range are not
            visible in the plot.

    Returns:
        The ``matplotlib.pyplot`` module with the new figure current, so the
        caller can chain ``.show()`` or ``.savefig(...)``.

    Raises:
        ValueError: If ``mean_att_dist`` is not two-dimensional.
    """
    mean_att_dist = np.asarray(mean_att_dist)
    if mean_att_dist.ndim != 2:
        raise ValueError(
            "mean_att_dist must have shape (num_layers, num_heads), "
            f"got shape {mean_att_dist.shape}"
        )
    num_layers, num_heads = mean_att_dist.shape
    layers = np.arange(num_layers)

    # subplots() also makes this figure current, so tight_layout and the
    # returned pyplot module operate on it.
    fig, ax = plt.subplots(figsize=(10, 6))
    for head in range(num_heads):
        ax.scatter(layers, mean_att_dist[:, head], label=f'Head {head}', s=20)

    ax.set_xlabel('Network depth (layer)')
    ax.set_ylabel('Mean attention distance (pixels)')
    # With 0 or 1 layers, (0, num_layers - 1) is an inverted or zero-width
    # range; leave autoscaling in charge in that case.
    if num_layers > 1:
        ax.set_xlim(0, num_layers - 1)
    ax.set_ylim(*ylim)

    # Build the legend once: the per-head entries plus a gray "..." proxy
    # entry. NOTE(review): the "..." entry is appended unconditionally, even
    # when every head is already listed — confirm this is intended.
    handles, labels = ax.get_legend_handles_labels()
    handles.append(
        plt.Line2D([0], [0], marker='o', color='w',
                   markerfacecolor='gray', markersize=5)
    )
    labels.append('...')
    ax.legend(handles, labels, loc='lower right', ncol=2, fontsize='small')
    fig.tight_layout()

    return plt