# app.py by dar-tau
# Last commit: fac5648 "Update app.py" (verified), 933 bytes
import gradio as gr
import numpy as np
from datasets import load_dataset
import matplotlib.pyplot as plt
import seaborn as sns
# Attention-map dataset loaded via `datasets.load_dataset`; this app only
# reads its 'train' split (rows are indexed by analyze_sentence).
dataset = load_dataset('dar-tau/grammar-attention-maps-opt-350m', split='train')
def _aggregate_attention(row):
    """Rebuild a row's attention tensor and collapse it to a single 2-D map.

    ``row['attention_maps']`` is stored flat and is reshaped with
    ``row['attention_maps_shape']``. Every axis except the last two is summed
    out, so the result is always a 2-D matrix. For a rank-4 tensor this equals
    ``.sum(1).sum(0)``. Presumably the axes are (layers, heads, target, source);
    TODO confirm against the dataset card.

    Raises:
        ValueError: if the reshaped tensor has fewer than two dimensions.
    """
    attn_maps = np.asarray(row['attention_maps']).reshape(*row['attention_maps_shape'])
    if attn_maps.ndim < 2:
        raise ValueError(
            f"expected an attention tensor with >= 2 dims, got shape {attn_maps.shape}"
        )
    # Fold all leading axes into one and sum it away, leaving (rows, cols).
    return attn_maps.reshape(-1, *attn_maps.shape[-2:]).sum(0)


def _tick_labels(text, n, tokenizer=None):
    """Return exactly *n* axis tick labels for an attention map over *text*.

    If *tokenizer* is given, its ``tokenize(text, add_special_tokens=False)``
    output is used, but only when it yields exactly *n* tokens. Otherwise the
    labels fall back to token positions ``'0'`` .. ``str(n - 1)``, so the ticks
    always match the matrix size.
    """
    if tokenizer is not None:
        tokens = tokenizer.tokenize(text, add_special_tokens=False)
        if len(tokens) == n:
            return list(tokens)
    return [str(i) for i in range(n)]


def analyze_sentence(index, tokenizer=None):
    """Return a dataset sentence and a heatmap of its aggregated attention.

    Args:
        index: position of the row in the module-level ``dataset``. ``None``
            (nothing selected in the UI) yields ``("", None)``.
        tokenizer: optional object exposing a Hugging Face-style
            ``tokenize(text, add_special_tokens=False)`` used to label the axes.
            When omitted, or when its token count does not match the map, the
            axes are labelled with token positions instead.

    Returns:
        ``(text, figure)``: the row's ``'text'`` and a matplotlib Figure
        holding the heatmap (TARGET on the y-axis, SOURCE on the x-axis).
    """
    if index is None:
        return "", None

    row = dataset[index]
    text = row['text']
    matrix = _aggregate_attention(row)
    n_rows, n_cols = matrix.shape

    # Use a fresh, explicit figure per call. gr.Plot needs a Figure
    # (sns.heatmap returns an Axes). Drawing on pyplot's shared current
    # figure would also stack heatmaps/colorbars across requests.
    fig, ax = plt.subplots()
    sns.heatmap(matrix, ax=ax)

    # +0.5 centres each tick on its heatmap cell.
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(_tick_labels(text, n_cols, tokenizer), rotation=90)
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(_tick_labels(text, n_rows, tokenizer), rotation=0)
    ax.set_ylabel('TARGET')
    ax.set_xlabel('SOURCE')
    ax.grid(True)

    # Release pyplot's reference so figures do not accumulate between calls;
    # the returned Figure object remains renderable.
    plt.close(fig)
    return text, fig
# With type='index' the Dropdown hands analyze_sentence the position of the
# chosen sentence in dataset['text'] rather than the sentence string itself.
sentence_picker = gr.Dropdown(choices=dataset['text'], type='index')
iface = gr.Interface(
    fn=analyze_sentence,
    inputs=[sentence_picker],
    outputs=[gr.Label(), gr.Plot(label="Plot")],
)
iface.launch()