File size: 1,945 Bytes
fe1089d
 
 
7b73dfd
 
 
 
fe1089d
 
d2116db
fe1089d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f5ebee7
d2116db
f5ebee7
fe1089d
7b73dfd
d2116db
fe1089d
7b73dfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# visualization module that creates an attention visualization using BERTViz


# external imports
from bertviz import neuron_view as nv


# internal imports
from utils import formatting as fmt
from .markup import markup_text


# explanation function that generates a response and builds the attention graphic and marked-up input text
def chat_explained(model, prompt):
    """Generate a response to ``prompt`` together with attention-based explanations.

    Args:
        model: project model wrapper; this function uses its ``set_config()``
            method and its ``TOKENIZER`` and ``MODEL`` attributes.
        prompt: input handed to the tokenizer.

    Returns:
        A 3-tuple ``(response_text, xai_graphic, marked_text)``: the formatted
        output text, the result of ``attention_graphic`` and the result of
        ``markup_text`` called with ``variant="visualizer"``.
    """
    model.set_config()

    tokenizer = model.TOKENIZER
    network = model.MODEL

    def _formatted_tokens(id_batch):
        # only the first sequence of the batch is converted back to tokens
        return fmt.format_tokens(tokenizer.convert_ids_to_tokens(id_batch[0]))

    # encode the prompt and let the model generate the output ids
    tokenized_prompt = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    input_ids = tokenized_prompt.input_ids
    output_ids = network.generate(input_ids, output_attentions=True)

    input_tokens = _formatted_tokens(input_ids)
    output_tokens = _formatted_tokens(output_ids)

    # run the model on input and generated ids with attentions requested,
    # then average the returned attention values
    forward_result = network(
        input_ids=input_ids,
        decoder_input_ids=output_ids,
        output_attentions=True,
    )
    mean_attention = fmt.avg_attention(forward_result)

    # response text, graphic and marked text for the ui (evaluated in this order)
    return (
        fmt.format_output_text(output_tokens),
        attention_graphic(input_tokens, output_tokens, model),
        markup_text(input_tokens, mean_attention, variant="visualizer"),
    )


def attention_graphic(encoder_text, decoder_text, model, *, layer=2, head=0):
    """Build the BERTViz neuron-view graphic for an input/output token pair.

    Args:
        encoder_text: iterable of input token strings; joined with single
            spaces to form sentence A.
        decoder_text: iterable of output token strings; joined with single
            spaces to form sentence B.
        model: project model wrapper providing ``MODEL`` and ``TOKENIZER``.
        layer: layer passed to the neuron view (defaults to 2, the value
            that was previously hard-coded).
        head: attention head passed to the neuron view (defaults to 0, the
            value that was previously hard-coded).

    Returns:
        Whatever ``nv.show`` returns when called with ``html_action="return"``.
    """
    # the model type is always reported as BERT (to fake out BERTViz)
    model_type = "bert"

    # create sentence a and b from the lists of token strings
    sentence_a = " ".join(encoder_text)
    sentence_b = " ".join(decoder_text)

    # build the neuron view and hand the result back to the caller
    return nv.show(
        model.MODEL,
        model_type,
        model.TOKENIZER,
        sentence_a,
        sentence_b,
        display_mode="light",
        layer=layer,
        head=head,
        html_action="return",
    )