import gradio as gr
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Load model and tokenizer
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
def get_token_probabilities(text, top_k=10):
    # Gradio sliders can deliver the value as a float; torch.topk requires an integer k
    top_k = int(top_k)

    # Tokenize the input text
    input_ids = tokenizer.encode(text, return_tensors="pt")

    # Get the last token's position
    last_token_position = input_ids.shape[1] - 1

    # Get model predictions
    with torch.no_grad():
        outputs = model(input_ids)
        logits = outputs.logits

    # Get probabilities for the next token after the last token
    next_token_logits = logits[0, last_token_position, :]
    next_token_probs = torch.softmax(next_token_logits, dim=0)

    # Get top k most likely tokens
    topk_probs, topk_indices = torch.topk(next_token_probs, top_k)

    # Convert to numpy for easier handling
    topk_probs = topk_probs.numpy()
    topk_indices = topk_indices.numpy()

    # Decode tokens
    topk_tokens = [tokenizer.decode([idx]) for idx in topk_indices]

    # Create a plot
    plt.figure(figsize=(10, 6))
    sns.barplot(x=topk_probs, y=topk_tokens)
    plt.title(f"Top {top_k} token probabilities after: '{text}'")
    plt.xlabel("Probability")
    plt.ylabel("Tokens")
    plt.tight_layout()

    # Ensure temp directory exists
    os.makedirs("tmp", exist_ok=True)

    # Save the plot to a file in the temp directory
    plot_path = os.path.join("tmp", "token_probabilities.png")
    plt.savefig(plot_path)
    plt.close()

    return plot_path, dict(zip(topk_tokens, topk_probs.tolist()))
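# Illustrative direct call (sketch, not wired into the app): the helper returns the
# saved plot path and a token -> probability mapping, e.g.
#   path, probs = get_token_probabilities("Hello, my name is", top_k=5)
#   # probs is a dict of the 5 most likely next tokens; actual values depend on the model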
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Next Token Probability Visualizer")
    gr.Markdown("Enter text and see the probabilities of possible next tokens.")

    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Type some text here...",
                value="Hello, my name is"
            )
            top_k = gr.Slider(
                minimum=5,
                maximum=20,
                value=10,
                step=1,
                label="Number of top tokens to show"
            )
            btn = gr.Button("Generate Probabilities")

        with gr.Column():
            output_image = gr.Image(label="Probability Distribution")
            output_table = gr.JSON(label="Token Probabilities")

    btn.click(
        fn=get_token_probabilities,
        inputs=[input_text, top_k],
        outputs=[output_image, output_table]
    )

    gr.Examples(
        examples=[
            ["Hello, my name is", 10],
            ["The capital of France is", 10],
            ["Once upon a time", 10],
            ["The best way to learn is to", 10]
        ],
        inputs=[input_text, top_k],
    )

# Launch the app
demo.launch()
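# Dependencies implied by the imports above (a likely requirements.txt for this Space;
# exact pins are an assumption and should be adjusted as needed):
#   gradio
#   torch
#   transformers
#   matplotlib
#   seaborn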