SaiMupparaju committed · commit 03653db · 0 parent(s)

Initial commit for MechVis Hugging Face Space
Browse files:
- .gitignore +26 -0
- 1_4_1_Indirect_Object_Identification_exercises.ipynb +0 -0
- Dockerfile +13 -0
- Procfile +1 -0
- README.md +80 -0
- README_HF.md +27 -0
- app.py +116 -0
- requirements.txt +5 -0
- space.yaml +7 -0
- templates/index.html +346 -0
.gitignore
ADDED
@@ -0,0 +1,26 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Distribution / packaging
+dist/
+build/
+*.egg-info/
+
+# Virtual environments
+venv/
+.env/
+
+# Environment variables
+.env
+
+# IDE files
+.vscode/
+.idea/
+
+# Jupyter Notebook
+.ipynb_checkpoints/
+
+# Miscellaneous
+.DS_Store
1_4_1_Indirect_Object_Identification_exercises.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
Dockerfile
ADDED
@@ -0,0 +1,13 @@
+FROM python:3.9
+
+WORKDIR /code
+
+COPY ./requirements.txt /code/requirements.txt
+
+RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+COPY . /code
+
+EXPOSE 7860
+
+CMD ["python", "app.py"]
Procfile
ADDED
@@ -0,0 +1 @@
+web: python app.py
README.md
ADDED
@@ -0,0 +1,80 @@
+# MechVis: GPT-2 Attention Head Contribution Visualization
+
+[MechVis on Hugging Face Spaces](https://huggingface.co/spaces/saivamsim26/mechvis)
+
+MechVis is a tool for visualizing how attention heads in GPT-2 small contribute to next token predictions. It provides a simple web interface where you can enter text, see what token the model predicts next, and visualize which attention heads contribute most to that prediction.
+
+This project is inspired by mechanistic interpretability research on language models, particularly studies of "indirect object identification" in GPT-2 small.
+
+## Features
+
+- Input any text prompt and see GPT-2's next token prediction
+- View a heatmap visualization of each attention head's contribution to the predicted token
+- Interactive tooltips showing exact contribution values for each head
+- Simple, clean web interface
+
+## Deployment on Hugging Face Spaces
+
+1. Create a new Space on Hugging Face:
+   - Go to https://huggingface.co/spaces
+   - Click "Create new Space"
+   - Choose "Docker" as the SDK
+   - Set the environment variables if needed
+
+2. Upload the following files to your Space:
+   - `app.py`
+   - `requirements.txt`
+   - `Dockerfile`
+   - Contents of `templates/` directory
+
+The application will automatically deploy and will be available at your Space's URL.
+
+## Local Development
+
+To run the application locally:
+
+```bash
+pip install -r requirements.txt
+python app.py
+```
+
+The application will be available at http://localhost:7860
+
+## How to Use
+
+1. Enter a text prompt in the input field
+2. Click "Predict Next Word"
+3. View the predicted token, its logit value, and probability
+4. Explore the heatmap visualization showing each attention head's contribution:
+   - Red cells indicate positive contributions to the predicted token
+   - Blue cells indicate negative contributions
+   - Hover over cells to see exact contribution values
+
+## Understanding the Visualization
+
+The visualization shows a 12×12 grid representing all attention heads in GPT-2 small, with:
+- Rows representing layers (0-11)
+- Columns representing heads within each layer (0-11)
+- Color intensity showing the magnitude of contribution
+
+This kind of visualization can help identify which attention heads are most important for specific prediction tasks. For example, research has shown that certain heads specialize in tasks like:
+- Name mover heads (e.g., 9.9, 10.0, 9.6)
+- Induction heads (e.g., 5.5, 6.9)
+- S-inhibition heads (e.g., 7.3, 7.9, 8.6, 8.10)
+
+## Example Use Cases
+
+1. **Indirect Object Identification**: Try entering "When John and Mary went to the store, John gave a drink to" and see which heads contribute to predicting "Mary"
+
+2. **Induction Pattern Detection**: Enter repetitive sequences like "The capital of France is Paris. The capital of Germany is" to see induction heads activate
+
+3. **Exploration**: Try various prompts to see how different heads specialize in different linguistic patterns
+
+## References
+
+- [Transformer Lens](https://github.com/neelnanda-io/TransformerLens) - Library for transformer interpretability
+- [Indirect Object Identification](https://arxiv.org/abs/2211.00593) - Research on circuits in GPT-2 small
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.
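A quick way to exercise the "Local Development" steps above is a small Python smoke test against the running server. This is a sketch, not part of the committed files: it assumes `python app.py` is already serving on port 7860, and it uses `requests`, which is not pinned in `requirements.txt` and would need to be installed separately. The form field name `text` and the "Prediction Results" heading both come from the committed `app.py` and `templates/index.html`.

```python
# Hedged sketch: POST the "text" form field (the name app.py reads)
# to the locally running Flask app and confirm an HTML page comes back.
import requests

resp = requests.post(
    "http://localhost:7860/",
    data={"text": "When John and Mary went to the store, John gave a drink to"},
)
print(resp.status_code)                    # expect 200
print("Prediction Results" in resp.text)   # heading rendered by templates/index.html
```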
README_HF.md
ADDED
@@ -0,0 +1,27 @@
+# MechVis: GPT-2 Attention Head Visualization
+
+This interactive web app allows you to visualize how different attention heads in GPT-2 small contribute to next token predictions.
+
+## How to Use
+
+1. Enter text in the input field (e.g., "When John and Mary went to the store, John gave a drink to")
+2. Click "Predict Next Word"
+3. See what token GPT-2 predicts next and explore how each attention head contributes to that prediction
+
+## Features
+
+- Next token prediction with GPT-2 small
+- Interactive heatmap showing attention head contributions
+- Layer contribution analysis
+- Hover over cells to see exact contribution values
+
+## Examples to Try
+
+- **Indirect Object Identification**: "When John and Mary went to the store, John gave a drink to" (likely predicts "Mary")
+- **Induction Pattern**: "The capital of France is Paris. The capital of Germany is" (likely predicts "Berlin")
+
+## About
+
+This project uses [TransformerLens](https://github.com/neelnanda-io/TransformerLens) to access internal model activations and calculate how each attention head contributes to the final logit score of the predicted token.
+
+[GitHub Repository](https://github.com/saivamsim26/mechvis)
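The "About" section describes projecting each head's output onto the predicted token's unembedding direction. A minimal standalone sketch of that core step, mirroring the einsums in `app.py` below (the TransformerLens attributes `W_O`, `W_U` and the `"z"` cache key are real; layer 9 is an arbitrary choice for illustration):

```python
# Hedged sketch of direct logit attribution for one layer: project each
# head's output at the last position onto the predicted token's
# unembedding direction.
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
tokens = model.to_tokens("When John and Mary went to the store, John gave a drink to")
logits, cache = model.run_with_cache(tokens)

token_idx = torch.argmax(logits[0, -1]).item()
direction = model.W_U[:, token_idx]                       # [d_model] unembedding column
z = cache["z", 9][0, -1]                                  # [head, d_head] outputs at last position
head_out = torch.einsum("hd,hdm->hm", z, model.W_O[9])    # [head, d_model]
print(torch.einsum("hm,m->h", head_out, direction))       # per-head logit contribution
```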
app.py
ADDED
@@ -0,0 +1,116 @@
+import torch
+import numpy as np
+from flask import Flask, render_template, request, jsonify
+from transformer_lens import HookedTransformer
+import json
+
+app = Flask(__name__)
+
+# Load GPT-2 small model
+model = HookedTransformer.from_pretrained(
+    "gpt2-small",
+    center_unembed=True,
+    center_writing_weights=True,
+    fold_ln=True,
+    refactor_factored_attn_matrices=True,
+)
+
+@app.route('/', methods=['GET', 'POST'])
+def index():
+    prediction = None
+    text = ""
+    head_contributions = None
+
+    if request.method == 'POST':
+        text = request.form.get('text', '')
+
+        if text:
+            # Tokenize the input text
+            tokens = model.to_tokens(text, prepend_bos=True)
+
+            # Run the model with cache to get intermediate activations
+            logits, cache = model.run_with_cache(tokens)
+
+            # Get logits for the last token
+            last_token_logits = logits[0, -1]
+
+            # Get the index of the token with the highest logit
+            top_token_idx = torch.argmax(last_token_logits).item()
+
+            # Get the logit value
+            top_token_logit = last_token_logits[top_token_idx].item()
+
+            # Get the probability
+            probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
+            top_token_prob = probs[top_token_idx].item() * 100  # Convert to percentage
+
+            # Get the token as a string
+            top_token_str = model.to_string([top_token_idx])
+
+            # Get attention head contributions for the top token
+            head_contributions = calculate_head_contributions(cache, top_token_idx, model)
+
+            prediction = {
+                'token': top_token_str,
+                'logit': top_token_logit,
+                'prob': top_token_prob
+            }
+
+    return render_template('index.html', prediction=prediction, text=text, head_contributions=json.dumps(head_contributions) if head_contributions else None)
+
+def calculate_head_contributions(cache, token_idx, model):
+    """Calculate the contribution of each attention head to the top token's logit."""
+
+    # Get all head outputs for the last token
+    head_outputs_by_layer = []
+    contributions = []
+    layer_total_contributions = []
+
+    # Get the direction in the residual stream that corresponds to the token
+    token_direction = model.W_U[:, token_idx].detach()
+
+    # Calculate contributions for each head
+    for layer in range(model.cfg.n_layers):
+        # Get the output of each head at the last position
+        z = cache["z", layer][0, -1]  # [head, d_head]
+
+        # Apply the OV matrix for each head
+        head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])  # [head, d_model]
+
+        # Project onto the token direction to get contribution to the logit
+        head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)
+
+        # Calculate total contribution for this layer
+        layer_total = head_contribs.sum().item()
+        layer_total_contributions.append(layer_total)
+
+        # Convert to list for JSON serialization
+        layer_contributions = head_contribs.detach().cpu().numpy().tolist()
+        contributions.append(layer_contributions)
+
+    # Calculate total contribution across all heads
+    total_contribution = sum([sum(layer_contrib) for layer_contrib in contributions])
+
+    # Convert contributions to percentage of total
+    percentage_contributions = []
+    for layer_contributions in contributions:
+        percentage_layer = [(contrib / total_contribution) * 100 for contrib in layer_contributions]
+        percentage_contributions.append(percentage_layer)
+
+    # Calculate per-layer contribution percentages
+    layer_percentages = [(layer_total / total_contribution) * 100 for layer_total in layer_total_contributions]
+
+    # Get the max and min values for normalization in visualization
+    all_contribs_pct = np.array(percentage_contributions).flatten()
+    max_contrib = float(np.max(all_contribs_pct))
+    min_contrib = float(np.min(all_contribs_pct))
+
+    return {
+        "contributions": percentage_contributions,
+        "max_value": max_contrib,
+        "min_value": min_contrib,
+        "layer_contributions": layer_percentages
+    }
+
+if __name__ == '__main__':
+    app.run(host="0.0.0.0", port=7860, debug=False)
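As a usage sketch, the attribution helper above can also be exercised outside Flask. This assumes `app.py` is importable from the working directory; importing it loads GPT-2 small but does not start the server, since `app.run` is guarded by `__main__`:

```python
# Hedged sketch: reuse app.py's model and attribution helper offline.
import torch
from app import model, calculate_head_contributions

tokens = model.to_tokens("The capital of France is Paris. The capital of Germany is", prepend_bos=True)
logits, cache = model.run_with_cache(tokens)
top = torch.argmax(logits[0, -1]).item()

result = calculate_head_contributions(cache, top, model)
print(model.to_string([top]))          # predicted token
print(result["layer_contributions"])   # per-layer contribution percentages
```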
requirements.txt
ADDED
@@ -0,0 +1,5 @@
+flask==2.0.1
+torch==2.0.1
+numpy>=1.21.0
+transformer-lens==1.2.2
+gunicorn==20.1.0
space.yaml
ADDED
@@ -0,0 +1,7 @@
+title: MechVis
+emoji: 📊
+colorFrom: indigo
+colorTo: purple
+sdk: docker
+app_port: 7860
+pinned: false
templates/index.html
ADDED
@@ -0,0 +1,346 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+    <meta charset="UTF-8">
+    <meta name="viewport" content="width=device-width, initial-scale=1.0">
+    <title>GPT-2 Next Word Prediction</title>
+    <link href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" rel="stylesheet">
+    <script src="https://d3js.org/d3.v7.min.js"></script>
+    <style>
+        body {
+            padding: 40px;
+            font-family: system-ui, -apple-system, sans-serif;
+        }
+        .prediction {
+            margin-top: 30px;
+            padding: 20px;
+            background-color: #f8f9fa;
+            border-radius: 5px;
+        }
+        .token {
+            font-size: 1.2rem;
+            font-weight: bold;
+            background-color: #e9ecef;
+            padding: 5px 10px;
+            border-radius: 4px;
+            display: inline-block;
+            margin-bottom: 10px;
+        }
+        #visualization {
+            margin-top: 30px;
+            width: 100%;
+            overflow-x: auto;
+        }
+        .head-cell {
+            stroke: #ddd;
+            stroke-width: 1px;
+        }
+        .layer-label, .head-label {
+            font-size: 12px;
+            font-weight: bold;
+            text-anchor: middle;
+        }
+        .tooltip {
+            position: absolute;
+            background-color: rgba(255, 255, 255, 0.9);
+            border: 1px solid #ddd;
+            padding: 8px;
+            border-radius: 4px;
+            pointer-events: none;
+            font-size: 12px;
+        }
+        .visualization-container {
+            margin-top: 30px;
+            background-color: white;
+            border-radius: 5px;
+            padding: 20px;
+            box-shadow: 0 0 10px rgba(0,0,0,0.1);
+        }
+        .legend {
+            margin-top: 15px;
+            margin-bottom: 20px;
+        }
+        .legend-item {
+            display: inline-block;
+            margin-right: 20px;
+        }
+        .legend-color {
+            display: inline-block;
+            width: 20px;
+            height: 20px;
+            margin-right: 5px;
+            vertical-align: middle;
+        }
+    </style>
+</head>
+<body>
+    <div class="container">
+        <h1 class="mb-4">GPT-2 Next Word Prediction</h1>
+
+        <div class="row">
+            <div class="col-md-12">
+                <form method="POST">
+                    <div class="mb-3">
+                        <label for="text" class="form-label">Input Text:</label>
+                        <textarea class="form-control" id="text" name="text" rows="3" placeholder="Enter text (e.g. 'When John and Mary went to the store, John gave a drink to')" required>{{ text }}</textarea>
+                    </div>
+                    <button type="submit" class="btn btn-primary">Predict Next Word</button>
+                </form>
+            </div>
+        </div>
+
+        {% if prediction %}
+        <div class="row">
+            <div class="col-md-12">
+                <div class="prediction">
+                    <h3>Prediction Results</h3>
+                    <p>Input text: <strong>{{ text }}</strong></p>
+                    <p>Next word: <span class="token">{{ prediction.token }}</span></p>
+                    <p>Logit value: <strong>{{ "%.4f"|format(prediction.logit) }}</strong></p>
+                    <p>Probability: <strong>{{ "%.2f"|format(prediction.prob) }}%</strong></p>
+                </div>
+            </div>
+        </div>
+
+        {% if head_contributions %}
+        <div class="row">
+            <div class="col-md-12">
+                <div class="visualization-container">
+                    <h3>Layer Contributions to Log Probability</h3>
+                    <p>This chart shows how each layer in GPT-2 contributes to the log probability of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                    <div id="layer-chart"></div>
+
+                    <h3>Attention Head Contributions</h3>
+                    <p>This visualization shows how each attention head in GPT-2 contributes to the prediction of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                    <div class="legend">
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #4575b4;"></div>
+                            <span>Negative contribution %</span>
+                        </div>
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #ffffbf;"></div>
+                            <span>Neutral (0%)</span>
+                        </div>
+                        <div class="legend-item">
+                            <div class="legend-color" style="background-color: #d73027;"></div>
+                            <span>Positive contribution %</span>
+                        </div>
+                    </div>
+
+                    <div id="visualization"></div>
+                </div>
+            </div>
+        </div>
+        {% endif %}
+        {% endif %}
+    </div>
+
+    <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
+
+    {% if head_contributions %}
+    <script>
+        document.addEventListener('DOMContentLoaded', function() {
+            const headContributions = {{ head_contributions|safe }};
+
+            // Create layer contributions bar chart
+            const createLayerChart = () => {
+                const layerContribs = headContributions.layer_contributions;
+                const margin = { top: 40, right: 30, bottom: 50, left: 60 };
+                const width = Math.min(800, window.innerWidth - 100);
+                const height = 300;
+
+                const svg = d3.select("#layer-chart")
+                    .append("svg")
+                    .attr("width", width)
+                    .attr("height", height);
+
+                const g = svg.append("g")
+                    .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                // Create scales
+                const x = d3.scaleBand()
+                    .domain(d3.range(layerContribs.length))
+                    .range([0, width - margin.left - margin.right])
+                    .padding(0.1);
+
+                const y = d3.scaleLinear()
+                    .domain([
+                        Math.min(0, d3.min(layerContribs)),
+                        Math.max(0, d3.max(layerContribs))
+                    ])
+                    .nice()
+                    .range([height - margin.top - margin.bottom, 0]);
+
+                // Create color scale - positive is green, negative is purple
+                const colorScale = d3.scaleLinear()
+                    .domain([Math.min(0, d3.min(layerContribs)), 0, Math.max(0, d3.max(layerContribs))])
+                    .range(["#9467bd", "#f7f7f7", "#2ca02c"]);
+
+                // Create tooltip
+                const tooltip = d3.select("body")
+                    .append("div")
+                    .attr("class", "tooltip")
+                    .style("opacity", 0);
+
+                // Create bars
+                g.selectAll(".bar")
+                    .data(layerContribs)
+                    .join("rect")
+                    .attr("class", "bar")
+                    .attr("x", (d, i) => x(i))
+                    .attr("y", d => d >= 0 ? y(d) : y(0))
+                    .attr("width", x.bandwidth())
+                    .attr("height", d => Math.abs(y(0) - y(d)))
+                    .attr("fill", d => colorScale(d))
+                    .attr("stroke", "#555")
+                    .attr("stroke-width", 1)
+                    .on("mouseover", function(event, d) {
+                        d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                        tooltip.transition().duration(200).style("opacity", 1);
+                        tooltip.html(`Layer ${layerContribs.indexOf(d)}<br>Contribution: ${d.toFixed(2)}%`)
+                            .style("left", (event.pageX + 10) + "px")
+                            .style("top", (event.pageY - 28) + "px");
+                    })
+                    .on("mouseout", function() {
+                        d3.select(this).attr("stroke", "#555").attr("stroke-width", 1);
+                        tooltip.transition().duration(500).style("opacity", 0);
+                    });
+
+                // Add x-axis
+                g.append("g")
+                    .attr("transform", `translate(0,${y(0)})`)
+                    .call(d3.axisBottom(x).tickFormat(i => `L${i}`))
+                    .selectAll("text")
+                    .style("font-size", "12px");
+
+                // Add y-axis
+                g.append("g")
+                    .call(d3.axisLeft(y).tickFormat(d => `${d.toFixed(1)}%`))
+                    .selectAll("text")
+                    .style("font-size", "12px");
+
+                // Add title
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", 20)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "16px")
+                    .style("font-weight", "bold")
+                    .text("Layer Contributions to Log Probability (%)");
+
+                // Add x-axis label
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", height - 10)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "14px")
+                    .text("Layer");
+
+                // Add y-axis label
+                svg.append("text")
+                    .attr("transform", "rotate(-90)")
+                    .attr("x", -(height / 2))
+                    .attr("y", 15)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "14px")
+                    .text("Contribution %");
+            };
+
+            // Create head contributions heatmap
+            const createHeadHeatmap = () => {
+                // Define visualization parameters
+                const cellSize = 40;
+                const numLayers = headContributions.contributions.length;
+                const numHeads = headContributions.contributions[0].length;
+                const margin = { top: 60, right: 20, bottom: 20, left: 60 };
+                const width = cellSize * numHeads + margin.left + margin.right;
+                const height = cellSize * numLayers + margin.top + margin.bottom;
+
+                // Create SVG
+                const svg = d3.select("#visualization")
+                    .append("svg")
+                    .attr("width", width)
+                    .attr("height", height);
+
+                // Create a group for the heatmap
+                const g = svg.append("g")
+                    .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                // Create color scale
+                const colorScale = d3.scaleSequential(d3.interpolateRdBu)
+                    .domain([headContributions.max_value, headContributions.min_value]);
+
+                // Create tooltip
+                const tooltip = d3.select("body")
+                    .append("div")
+                    .attr("class", "tooltip")
+                    .style("opacity", 0);
+
+                // Create cells
+                for (let layer = 0; layer < numLayers; layer++) {
+                    for (let head = 0; head < numHeads; head++) {
+                        const contribution = headContributions.contributions[layer][head];
+
+                        g.append("rect")
+                            .attr("class", "head-cell")
+                            .attr("x", head * cellSize)
+                            .attr("y", layer * cellSize)
+                            .attr("width", cellSize)
+                            .attr("height", cellSize)
+                            .attr("fill", colorScale(contribution))
+                            .on("mouseover", function(event) {
+                                d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                                tooltip.transition().duration(200).style("opacity", 1);
+                                tooltip.html(`Layer ${layer}, Head ${head}<br>Contribution: ${contribution.toFixed(2)}%`)
+                                    .style("left", (event.pageX + 10) + "px")
+                                    .style("top", (event.pageY - 28) + "px");
+                            })
+                            .on("mouseout", function() {
+                                d3.select(this).attr("stroke", "#ddd").attr("stroke-width", 1);
+                                tooltip.transition().duration(500).style("opacity", 0);
+                            });
+                    }
+                }
+
+                // Add layer labels
+                for (let layer = 0; layer < numLayers; layer++) {
+                    g.append("text")
+                        .attr("class", "layer-label")
+                        .attr("x", -10)
+                        .attr("y", layer * cellSize + cellSize / 2)
+                        .attr("text-anchor", "end")
+                        .attr("dominant-baseline", "middle")
+                        .text(`L${layer}`);
+                }
+
+                // Add head labels
+                for (let head = 0; head < numHeads; head++) {
+                    g.append("text")
+                        .attr("class", "head-label")
+                        .attr("x", head * cellSize + cellSize / 2)
+                        .attr("y", -10)
+                        .attr("text-anchor", "middle")
+                        .attr("dominant-baseline", "central")
+                        .text(`H${head}`);
+                }
+
+                // Add title
+                svg.append("text")
+                    .attr("x", width / 2)
+                    .attr("y", 20)
+                    .attr("text-anchor", "middle")
+                    .style("font-size", "16px")
+                    .style("font-weight", "bold")
+                    .text("Head Contributions to Log Probability (%)");
+            };
+
+            // Create both visualizations
+            createLayerChart();
+            createHeadHeatmap();
+        });
+    </script>
+    {% endif %}
+</body>
+</html>