Spaces:
Runtime error
| import gradio as gr | |
| import numpy as np | |
| from datasets import load_dataset | |
| from transformers import AutoTokenizer | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from enum import Enum | |
class VisType(Enum):
    """Visualization modes selectable in the "Visualization" dropdown.

    Members:
        ALL: sum the attention maps over their leading axis into one heatmap.
    """

    ALL = 'ALL'
# Hugging Face Hub identifiers. The dataset name suggests its attention maps were
# produced with OPT-350m, so the matching tokenizer is loaded to label tokens
# (presumably so token boundaries line up with the stored maps — confirm).
_DATASET_REPO = 'dar-tau/grammar-attention-maps-opt-350m'
_TOKENIZER_REPO = 'facebook/opt-350m'

dataset = load_dataset(_DATASET_REPO)['train']
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_REPO, add_prefix_space=True)
# Row fields holding the raw input rather than a metric; they are excluded from
# the metrics dictionary returned alongside the plot.
_NON_METRIC_KEYS = frozenset({'text', 'attention_maps', 'attention_maps_shape'})


def _as_vis_type(vis_type):
    """Coerce a dropdown value (member, value string, or 'VisType.NAME') to VisType."""
    if isinstance(vis_type, VisType):
        return vis_type
    try:
        return VisType(vis_type)
    except ValueError:
        # str() of an Enum member looks like 'VisType.ALL'; accept that form too.
        name = str(vis_type).rsplit('.', 1)[-1]
        try:
            return VisType[name]
        except KeyError:
            raise ValueError(
                f'Unsupported visualization type: {vis_type!r}') from None


def analyze_sentence(index, vis_type):
    """Plot the attention heatmap for one dataset sentence.

    Args:
        index: Row position in ``dataset`` (the sentence dropdown uses
            ``type='index'``, so it delivers an int).
        vis_type: A ``VisType`` member or its string value (Gradio dropdowns
            deliver strings).

    Returns:
        ``(fig, metrics)``: the matplotlib figure containing the heatmap, and a
        dict of every row field except the text and raw attention data. The
        button callback routes these to the Plot and Label outputs.

    Raises:
        ValueError: if ``vis_type`` is not a supported visualization.
    """
    vis_type = _as_vis_type(vis_type)

    row = dataset[index]
    text = row['text']
    # Decode each token id on its own so every token gets its own tick label.
    tokenized = tokenizer.batch_decode(
        tokenizer.encode(text, add_special_tokens=False))

    # The stored shape's first entry is dropped; the remaining shape is 3-D (the
    # slicing below indexes three axes) and its axis 1 is the sequence length.
    # NOTE(review): axis 0 presumably enumerates layers/heads — confirm.
    attn_map_shape = row['attention_maps_shape'][1:]
    seq_len = attn_map_shape[1]
    attn_maps = np.array(row['attention_maps']).reshape(*attn_map_shape).clip(0, 1)
    # Drop the first token's row and column; its label (tokenized[0]) is also
    # skipped on the axes below.
    attn_maps = attn_maps[:, 1:, 1:]

    if vis_type == VisType.ALL:
        plot_data = attn_maps.sum(0)
    else:
        # Previously this left plot_data unbound and failed with a NameError.
        raise ValueError(f'Unsupported visualization type: {vis_type!r}')

    fig = plt.figure(figsize=(0.5 + 0.5 * len(tokenized), 0.4 * len(tokenized)))
    sns.heatmap(plot_data)
    # Offset by 0.5 so labels sit in the middle of each heatmap cell.
    tick_positions = np.arange(seq_len - 1) + 0.5
    plt.xticks(tick_positions, tokenized[1:], rotation=90)
    plt.yticks(tick_positions, tokenized[1:], rotation=0)
    plt.ylabel('TARGET')
    plt.xlabel('SOURCE')
    plt.grid()

    metrics = {k: v for k, v in row.items() if k not in _NON_METRIC_KEYS}
    return fig, metrics
def _sentence_label(text):
    """Return the dropdown label for a dataset text.

    Texts carry a '</s> ' marker; the label is the segment after the first
    marker (as before). Texts without the marker are shown unchanged instead of
    raising IndexError.
    """
    parts = text.split('</s> ')
    return parts[1] if len(parts) > 1 else parts[0]


demo = gr.Blocks()
with demo:
    sentence_choices = [_sentence_label(x) for x in dataset['text']]
    with gr.Row():
        # type='index' makes the callback receive the row position, not the
        # label. The initial value must be one of the choices, not an index.
        sentence_dropdown = gr.Dropdown(
            label="Sentence",
            choices=sentence_choices,
            value=sentence_choices[0] if sentence_choices else None,
            min_width=500,
            type='index',
        )
        # Enum members are not JSON-serializable for the frontend, so offer
        # their string values; analyze_sentence coerces them back to VisType.
        vis_dropdown = gr.Dropdown(
            label="Visualization",
            choices=[v.value for v in VisType],
            value=VisType.ALL.value,
            min_width=100,
            type='value',
        )
        btn = gr.Button("Run", min_width=50)
    output = gr.Plot(label="Plot", container=True)
    # The first positional argument of gr.Label is its value, so the caption
    # must be passed as label=.
    # NOTE(review): gr.Label expects a {str: float} mapping; confirm that the
    # non-text dataset columns are numeric.
    metrics = gr.Label(label="Metrics")
    btn.click(analyze_sentence, [sentence_dropdown, vis_dropdown], [output, metrics])

if __name__ == "__main__":
    demo.launch()