# visualization module that creates an attention heatmap using matplotlib and seaborn

# external imports
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# internal imports
from utils import formatting as fmt
from .markup import markup_text


# generates a model response for the given prompt and explains it by
# visualizing the averaged cross-attention weights as a heatmap
def chat_explained(model, prompt):
    # apply the model's generation configuration
    model.set_config()

    # tokenize the prompt to get the encoder input ids
    encoder_input_ids = model.TOKENIZER(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids
    # generate the response ids (attentions are collected in a separate
    # forward pass below)
    decoder_input_ids = model.MODEL.generate(encoder_input_ids)

    # convert the input and output ids to readable token strings
    encoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
    )
    decoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
    )

    # run the model on the full input/output pair to obtain the
    # cross-attention weights between encoder and decoder tokens
    attention_output = model.MODEL(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
    )

    # average the first layer's cross-attention weights over all heads
    averaged_attention = avg_attention(attention_output)

    # create the response text, the heatmap plot, and the marked-up input text
    response_text = fmt.format_output_text(decoder_text)
    plot = create_plot(averaged_attention, (encoder_text, decoder_text))
    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")

    return response_text, "", plot, marked_text


# creates an attention heatmap plot using matplotlib/seaborn
# CREDIT: adapted from the official Matplotlib heatmap documentation
## see https://matplotlib.org/stable/
def create_plot(averaged_attention_weights, enc_dec_texts: tuple):
    # transpose so that encoder tokens index the rows and decoder tokens the columns
    averaged_attention_weights = np.transpose(averaged_attention_weights)

    # get the encoder and decoder tokens in text form
    encoder_tokens = enc_dec_texts[0]
    decoder_tokens = enc_dec_texts[1]

    # set seaborn style to white and initialize figure and axis
    sns.set(style="white")
    fig, ax = plt.subplots()

    # Setting figure size
    fig.set_size_inches(
        max(averaged_attention_weights.shape[1] * 2, 10),
        max(averaged_attention_weights.shape[0] * 1, 5),
    )

    # Plotting the heatmap with seaborn's color palette
    im = ax.imshow(
        averaged_attention_weights,
        vmax=averaged_attention_weights.max(),
        vmin=averaged_attention_weights.min(),
        cmap=sns.color_palette("rocket", as_cmap=True),
        aspect="auto",
    )

    # Creating colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Attention Weight Scale", rotation=-90, va="bottom")
    cbar.ax.yaxis.set_tick_params(color="black")
    plt.setp(plt.getp(cbar.ax.axes, "yticklabels"), color="black")

    # Setting ticks and labels with black color for visibility
    ax.set_yticks(np.arange(len(encoder_tokens)), labels=encoder_tokens)
    ax.set_xticks(np.arange(len(decoder_tokens)), labels=decoder_tokens)
    ax.set_title("Attention Weights by Token")
    plt.setp(ax.get_xticklabels(), color="black", rotation=45, ha="right")
    plt.setp(ax.get_yticklabels(), color="black")

    # Adding text annotations with appropriate contrast
    for i in range(averaged_attention_weights.shape[0]):
        for j in range(averaged_attention_weights.shape[1]):
            val = averaged_attention_weights[i, j]
            color = (
                "white"
                if im.norm(averaged_attention_weights.max()) / 2 > im.norm(val)
                else "black"
            )
            ax.text(j, i, f"{val:.4f}", ha="center", va="center", color=color)

    # return the plot
    return plt
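

# Illustrative call (an assumption, not shipped with the module):
# create_plot expects weights of shape (decoder_len, encoder_len) prior to
# the internal transpose, e.g. a dummy 3x2 matrix for two encoder tokens
# and three decoder tokens:
#
#     create_plot(np.random.rand(3, 2), (["the", "cat"], ["le", "chat", "</s>"])).show()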


# averages the cross-attention weights of the first decoder layer over all heads
def avg_attention(attention_values):
    # cross_attentions[0][0]: first layer, first batch element;
    # shape (num_heads, target_len, source_len)
    attention = attention_values.cross_attentions[0][0].detach().numpy()
    # mean over the heads axis -> shape (target_len, source_len)
    return np.mean(attention, axis=0)
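

# Usage sketch (an assumption, not part of the module): chat_explained
# expects a wrapper object exposing TOKENIZER, MODEL and set_config().
# The _DemoModel wrapper class and the model checkpoint below are
# hypothetical stand-ins built on a Hugging Face seq2seq model.
if __name__ == "__main__":
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    class _DemoModel:
        TOKENIZER = AutoTokenizer.from_pretrained("google/flan-t5-small")
        MODEL = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

        def set_config(self):
            # placeholder: a real wrapper might set generation parameters here
            pass

    response_text, _, plot, marked_text = chat_explained(
        _DemoModel(), "Translate to German: Hello, how are you?"
    )
    print(response_text)
    plot.show()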