# Spaces:
# Runtime error
# Runtime error
# Visualization module that creates attention visualizations using BERTViz.
# external imports
import torch
from bertviz import neuron_view as nv

# internal imports
from utils import formatting as fmt
from .markup import markup_text
# explanation function: generates a response and builds attention-based visualizations for the UI
def chat_explained(model, prompt):
    """Generate a response to *prompt* and build attention-based explanations.

    Args:
        model: Wrapper exposing ``set_config()``, a ``TOKENIZER`` and a
            ``MODEL``. ``MODEL`` is called with ``decoder_input_ids``, so it
            is presumably an encoder-decoder (seq2seq) model — TODO confirm.
        prompt: Input text fed to the encoder.

    Returns:
        A tuple ``(response_text, xai_graphic, marked_text)``:
        the formatted generated text, the BERTViz neuron-view HTML produced by
        ``attention_graphic``, and the prompt tokens marked up with the
        averaged attention values (``variant="visualizer"``).
    """
    model.set_config()

    # tokenize the prompt (including special tokens) into PyTorch tensors
    encoder_input_ids = model.TOKENIZER(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids
    # attentions are requested from the explicit forward pass below; asking
    # generate() for them (without return_dict_in_generate=True) does not
    # change the returned token ids and only costs extra compute
    decoder_input_ids = model.MODEL.generate(encoder_input_ids)

    # human-readable token lists for both sides of the model
    encoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
    )
    decoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
    )

    # re-run the model on the generated ids purely to read its attentions;
    # this is inference only, so disable autograd to avoid building a graph
    # and to hand non-grad tensors to the formatting helpers
    with torch.no_grad():
        attention_output = model.MODEL(
            input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            output_attentions=True,
        )
    averaged_attention = fmt.avg_attention(attention_output)

    # assemble the three UI artifacts
    response_text = fmt.format_output_text(decoder_text)
    xai_graphic = attention_graphic(encoder_text, decoder_text, model)
    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
    return response_text, xai_graphic, marked_text
def attention_graphic(
    encoder_text, decoder_text, model, *, layer=2, head=0, display_mode="light"
):
    """Render a BERTViz neuron view of *model*'s attention and return it as HTML.

    Args:
        encoder_text: List of input token strings; joined with spaces to form
            BERTViz's ``sentence_a``.
        decoder_text: List of output token strings; joined with spaces to form
            BERTViz's ``sentence_b``.
        model: Wrapper exposing ``MODEL`` and ``TOKENIZER``, which are passed
            straight to ``neuron_view.show``.
        layer: Attention layer to display initially (default 2).
        head: Attention head to display initially (default 0).
        display_mode: BERTViz colour theme (default ``"light"``).

    Returns:
        Whatever ``neuron_view.show`` returns with ``html_action="return"``.
    """
    # label the model as BERT so BERTViz accepts it ("fake out BERTViz").
    # NOTE(review): BERTViz documents the neuron view as requiring its own
    # patched model classes; a plain Hugging Face model may fail here — confirm.
    model_type = "bert"

    # BERTViz takes whole sentences, so rebuild them from the token lists
    sentence_a = " ".join(encoder_text)
    sentence_b = " ".join(decoder_text)

    return nv.show(
        model.MODEL,
        model_type,
        model.TOKENIZER,
        sentence_a,
        sentence_b,
        display_mode=display_mode,
        layer=layer,
        head=head,
        html_action="return",
    )