SaiMupparaju committed
Commit 03653db · 0 parent(s)

Initial commit for MechVis Hugging Face Space
.gitignore ADDED
@@ -0,0 +1,26 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # Distribution / packaging
+ dist/
+ build/
+ *.egg-info/
+
+ # Virtual environments
+ venv/
+ .env/
+
+ # Environment variables
+ .env
+
+ # IDE files
+ .vscode/
+ .idea/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints/
+
+ # Miscellaneous
+ .DS_Store
1_4_1_Indirect_Object_Identification_exercises.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM python:3.9
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ COPY . /code
+
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
Procfile ADDED
@@ -0,0 +1 @@
+ web: python app.py
README.md ADDED
@@ -0,0 +1,80 @@
+ # MechVis: GPT-2 Attention Head Contribution Visualization
+
+ [![Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/saivamsim26/mechvis)
+
+ MechVis is a tool for visualizing how attention heads in GPT-2 small contribute to next-token predictions. It provides a simple web interface where you can enter text, see which token the model predicts next, and visualize which attention heads contribute most to that prediction.
+
+ This project is inspired by mechanistic interpretability research on language models, particularly studies of "indirect object identification" in GPT-2 small.
+
+ ## Features
+
+ - Input any text prompt and see GPT-2's next-token prediction
+ - View a heatmap visualization of each attention head's contribution to the predicted token
+ - Interactive tooltips showing exact contribution values for each head
+ - Simple, clean web interface
+
+ ## Deployment on Hugging Face Spaces
+
+ 1. Create a new Space on Hugging Face:
+    - Go to https://huggingface.co/spaces
+    - Click "Create new Space"
+    - Choose "Docker" as the SDK
+    - Set the environment variables if needed
+
+ 2. Upload the following files to your Space:
+    - `app.py`
+    - `requirements.txt`
+    - `Dockerfile`
+    - Contents of the `templates/` directory
+
+ The application will deploy automatically and be available at your Space's URL.
+
+ ## Local Development
+
+ To run the application locally:
+
+ ```bash
+ pip install -r requirements.txt
+ python app.py
+ ```
+
+ The application will be available at http://localhost:7860.
+
+ ## How to Use
+
+ 1. Enter a text prompt in the input field
+ 2. Click "Predict Next Word"
+ 3. View the predicted token, its logit value, and probability
+ 4. Explore the heatmap visualization showing each attention head's contribution:
+    - Red cells indicate positive contributions to the predicted token
+    - Blue cells indicate negative contributions
+    - Hover over cells to see exact contribution values
+
+ ## Understanding the Visualization
+
+ The visualization shows a 12×12 grid representing all attention heads in GPT-2 small, with:
+ - Rows representing layers (0-11)
+ - Columns representing heads within each layer (0-11)
+ - Color intensity showing the magnitude of contribution
+
+ This kind of visualization can help identify which attention heads are most important for specific prediction tasks. For example, research has shown that certain heads specialize in tasks like:
+ - Name mover heads (e.g., 9.9, 10.0, 9.6)
+ - Induction heads (e.g., 5.5, 6.9)
+ - S-inhibition heads (e.g., 7.3, 7.9, 8.6, 8.10)
+
+ ## Example Use Cases
+
+ 1. **Indirect Object Identification**: Try entering "When John and Mary went to the store, John gave a drink to" and see which heads contribute to predicting "Mary"
+
+ 2. **Induction Pattern Detection**: Enter repetitive sequences like "The capital of France is Paris. The capital of Germany is" to see induction heads activate
+
+ 3. **Exploration**: Try various prompts to see how different heads specialize in different linguistic patterns
+
+ ## References
+
+ - [TransformerLens](https://github.com/neelnanda-io/TransformerLens) - Library for transformer interpretability
+ - [Indirect Object Identification](https://arxiv.org/abs/2211.00593) - Research on circuits in GPT-2 small
+
+ ## License
+
+ This project is licensed under the MIT License - see the LICENSE file for details.
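
The head scores in the heatmap are a form of direct logit attribution: each head's output at the final position is projected onto the unembedding column of the predicted token, and the dot product is that head's direct contribution to the token's logit. A minimal sketch of the idea, assuming TransformerLens is installed and using the IOI prompt above (variable names here are illustrative, not taken from `app.py`):

```python
import torch
from transformer_lens import HookedTransformer

# Same weight-folding options the app uses when loading the model.
model = HookedTransformer.from_pretrained(
    "gpt2-small", fold_ln=True, center_unembed=True, center_writing_weights=True
)

prompt = "When John and Mary went to the store, John gave a drink to"
tokens = model.to_tokens(prompt, prepend_bos=True)
logits, cache = model.run_with_cache(tokens)
top_token = logits[0, -1].argmax().item()

# W_U[:, t] is the residual-stream direction whose dot product with any
# vector gives that vector's direct contribution to token t's logit.
direction = model.W_U[:, top_token]

# Contribution of one head (layer 9, head 9, a "name mover") at the last position.
layer, head = 9, 9
z = cache["z", layer][0, -1, head]              # [d_head], head output before W_O
residual_write = z @ model.W_O[layer, head]     # [d_model], what the head adds
print(model.to_string([top_token]), (residual_write @ direction).item())
```
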
README_HF.md ADDED
@@ -0,0 +1,27 @@
+ # MechVis: GPT-2 Attention Head Visualization
+
+ This interactive web app allows you to visualize how different attention heads in GPT-2 small contribute to next-token predictions.
+
+ ## How to Use
+
+ 1. Enter text in the input field (e.g., "When John and Mary went to the store, John gave a drink to")
+ 2. Click "Predict Next Word"
+ 3. See what token GPT-2 predicts next and explore how each attention head contributes to that prediction
+
+ ## Features
+
+ - Next-token prediction with GPT-2 small
+ - Interactive heatmap showing attention head contributions
+ - Layer contribution analysis
+ - Hover over cells to see exact contribution values
+
+ ## Examples to Try
+
+ - **Indirect Object Identification**: "When John and Mary went to the store, John gave a drink to" (likely predicts "Mary")
+ - **Induction Pattern**: "The capital of France is Paris. The capital of Germany is" (likely predicts "Berlin")
+
+ ## About
+
+ This project uses [TransformerLens](https://github.com/neelnanda-io/TransformerLens) to access internal model activations and calculate how each attention head contributes to the final logit score of the predicted token.
+
+ [GitHub Repository](https://github.com/saivamsim26/mechvis)
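
The prediction step itself reduces to a few lines of TransformerLens; a minimal sketch using the induction example above (standard API calls, nothing app-specific):

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")

prompt = "The capital of France is Paris. The capital of Germany is"
tokens = model.to_tokens(prompt, prepend_bos=True)  # [batch=1, seq_len]
logits = model(tokens)                              # [batch, seq_len, d_vocab]

# The next-token prediction lives at the last sequence position.
last = logits[0, -1]
top = last.argmax().item()
prob = torch.softmax(last, dim=-1)[top].item()
print(repr(model.to_string([top])), f"{prob:.1%}")  # expected: ' Berlin'
```
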
app.py ADDED
@@ -0,0 +1,116 @@
+ import torch
+ import numpy as np
+ from flask import Flask, render_template, request
+ from transformer_lens import HookedTransformer
+ import json
+
+ app = Flask(__name__)
+
+ # Load GPT-2 small model
+ model = HookedTransformer.from_pretrained(
+     "gpt2-small",
+     center_unembed=True,
+     center_writing_weights=True,
+     fold_ln=True,
+     refactor_factored_attn_matrices=True,
+ )
+
+ @app.route('/', methods=['GET', 'POST'])
+ def index():
+     prediction = None
+     text = ""
+     head_contributions = None
+
+     if request.method == 'POST':
+         text = request.form.get('text', '')
+
+         if text:
+             # Tokenize the input text
+             tokens = model.to_tokens(text, prepend_bos=True)
+
+             # Run the model with cache to get intermediate activations
+             logits, cache = model.run_with_cache(tokens)
+
+             # Get logits for the last token
+             last_token_logits = logits[0, -1]
+
+             # Get the index of the token with the highest logit
+             top_token_idx = torch.argmax(last_token_logits).item()
+
+             # Get the logit value
+             top_token_logit = last_token_logits[top_token_idx].item()
+
+             # Get the probability
+             probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
+             top_token_prob = probs[top_token_idx].item() * 100  # Convert to percentage
+
+             # Get the token as a string
+             top_token_str = model.to_string([top_token_idx])
+
+             # Get attention head contributions for the top token
+             head_contributions = calculate_head_contributions(cache, top_token_idx, model)
+
+             prediction = {
+                 'token': top_token_str,
+                 'logit': top_token_logit,
+                 'prob': top_token_prob
+             }
+
+     return render_template('index.html', prediction=prediction, text=text,
+                            head_contributions=json.dumps(head_contributions) if head_contributions else None)
+
+ def calculate_head_contributions(cache, token_idx, model):
+     """Calculate the contribution of each attention head to the top token's logit."""
+     # Per-layer lists of head contributions, accumulated over the loop below
+     contributions = []
+     layer_total_contributions = []
+
+     # Get the direction in the residual stream that corresponds to the token
+     token_direction = model.W_U[:, token_idx].detach()
+
+     # Calculate contributions for each head
+     for layer in range(model.cfg.n_layers):
+         # Get the output of each head at the last position
+         z = cache["z", layer][0, -1]  # [head, d_head]
+
+         # Apply the output matrix (W_O) for each head
+         head_outputs = torch.einsum("hd,hdm->hm", z, model.W_O[layer])  # [head, d_model]
+
+         # Project onto the token direction to get each head's contribution to the logit
+         head_contribs = torch.einsum("hm,m->h", head_outputs, token_direction)
+
+         # Calculate total contribution for this layer
+         layer_total = head_contribs.sum().item()
+         layer_total_contributions.append(layer_total)
+
+         # Convert to list for JSON serialization
+         layer_contributions = head_contribs.detach().cpu().numpy().tolist()
+         contributions.append(layer_contributions)
+
+     # Calculate total contribution across all heads
+     total_contribution = sum(sum(layer_contrib) for layer_contrib in contributions)
+
+     # Convert contributions to percentage of total
+     percentage_contributions = []
+     for layer_contributions in contributions:
+         percentage_layer = [(contrib / total_contribution) * 100 for contrib in layer_contributions]
+         percentage_contributions.append(percentage_layer)
+
+     # Calculate per-layer contribution percentages
+     layer_percentages = [(layer_total / total_contribution) * 100 for layer_total in layer_total_contributions]
+
+     # Get the max and min values for normalization in the visualization
+     all_contribs_pct = np.array(percentage_contributions).flatten()
+     max_contrib = float(np.max(all_contribs_pct))
+     min_contrib = float(np.min(all_contribs_pct))
+
+     return {
+         "contributions": percentage_contributions,
+         "max_value": max_contrib,
+         "min_value": min_contrib,
+         "layer_contributions": layer_percentages
+     }
+
+ if __name__ == '__main__':
+     app.run(host="0.0.0.0", port=7860, debug=False)
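
One caveat about `calculate_head_contributions`: projecting head outputs straight onto `W_U` skips the final LayerNorm, so the scores are an approximation. TransformerLens can produce a LayerNorm-adjusted per-head decomposition for comparison; a sketch under that assumption, reusing the `model` object loaded above (`stack_head_results` and `apply_ln_to_stack` are TransformerLens API; the rest is illustrative):

```python
import torch

prompt = "When John and Mary went to the store, John gave a drink to"
tokens = model.to_tokens(prompt, prepend_bos=True)
logits, cache = model.run_with_cache(tokens)
top_token = logits[0, -1].argmax().item()

with torch.no_grad():
    # Stack every head's write to the residual stream at the final position,
    # then scale it the way the final LayerNorm would.
    per_head, labels = cache.stack_head_results(
        layer=-1, pos_slice=-1, return_labels=True
    )
    per_head = cache.apply_ln_to_stack(per_head, layer=-1, pos_slice=-1)
    head_logits = per_head[:, 0] @ model.W_U[:, top_token]  # [n_layers * n_heads]

# Top 5 heads by direct, LayerNorm-adjusted contribution to the logit.
for score, label in sorted(zip(head_logits.tolist(), labels), reverse=True)[:5]:
    print(f"{label}: {score:+.3f}")
```
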
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ flask==2.0.1
+ werkzeug==2.0.3  # Flask 2.0.x breaks with newer Werkzeug releases
+ torch==2.0.1
+ numpy>=1.21.0
+ transformer-lens==1.2.2
+ gunicorn==20.1.0
space.yaml ADDED
@@ -0,0 +1,7 @@
+ title: MechVis
+ emoji: 📊
+ colorFrom: indigo
+ colorTo: purple
+ sdk: docker
+ app_port: 7860
+ pinned: false
templates/index.html ADDED
@@ -0,0 +1,346 @@
+ <!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <meta charset="UTF-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0">
+     <title>GPT-2 Next Word Prediction</title>
+     <link href="https://cdn.jsdelivr.net/npm/bootstrap@5/dist/css/bootstrap.min.css" rel="stylesheet">
+     <script src="https://d3js.org/d3.v7.min.js"></script>
+     <style>
+         body {
+             padding: 40px;
+             font-family: system-ui, -apple-system, sans-serif;
+         }
+         .prediction {
+             margin-top: 30px;
+             padding: 20px;
+             background-color: #f8f9fa;
+             border-radius: 5px;
+         }
+         .token {
+             font-size: 1.2rem;
+             font-weight: bold;
+             background-color: #e9ecef;
+             padding: 5px 10px;
+             border-radius: 4px;
+             display: inline-block;
+             margin-bottom: 10px;
+         }
+         #visualization {
+             margin-top: 30px;
+             width: 100%;
+             overflow-x: auto;
+         }
+         .head-cell {
+             stroke: #ddd;
+             stroke-width: 1px;
+         }
+         .layer-label, .head-label {
+             font-size: 12px;
+             font-weight: bold;
+             text-anchor: middle;
+         }
+         .tooltip {
+             position: absolute;
+             background-color: rgba(255, 255, 255, 0.9);
+             border: 1px solid #ddd;
+             padding: 8px;
+             border-radius: 4px;
+             pointer-events: none;
+             font-size: 12px;
+         }
+         .visualization-container {
+             margin-top: 30px;
+             background-color: white;
+             border-radius: 5px;
+             padding: 20px;
+             box-shadow: 0 0 10px rgba(0,0,0,0.1);
+         }
+         .legend {
+             margin-top: 15px;
+             margin-bottom: 20px;
+         }
+         .legend-item {
+             display: inline-block;
+             margin-right: 20px;
+         }
+         .legend-color {
+             display: inline-block;
+             width: 20px;
+             height: 20px;
+             margin-right: 5px;
+             vertical-align: middle;
+         }
+     </style>
+ </head>
+ <body>
+     <div class="container">
+         <h1 class="mb-4">GPT-2 Next Word Prediction</h1>
+
+         <div class="row">
+             <div class="col-md-12">
+                 <form method="POST">
+                     <div class="mb-3">
+                         <label for="text" class="form-label">Input Text:</label>
+                         <textarea class="form-control" id="text" name="text" rows="3" placeholder="Enter text (e.g. 'When John and Mary went to the store, John gave a drink to')" required>{{ text }}</textarea>
+                     </div>
+                     <button type="submit" class="btn btn-primary">Predict Next Word</button>
+                 </form>
+             </div>
+         </div>
+
+         {% if prediction %}
+         <div class="row">
+             <div class="col-md-12">
+                 <div class="prediction">
+                     <h3>Prediction Results</h3>
+                     <p>Input text: <strong>{{ text }}</strong></p>
+                     <p>Next word: <span class="token">{{ prediction.token }}</span></p>
+                     <p>Logit value: <strong>{{ "%.4f"|format(prediction.logit) }}</strong></p>
+                     <p>Probability: <strong>{{ "%.2f"|format(prediction.prob) }}%</strong></p>
+                 </div>
+             </div>
+         </div>
+
+         {% if head_contributions %}
+         <div class="row">
+             <div class="col-md-12">
+                 <div class="visualization-container">
+                     <h3>Layer Contributions to Log Probability</h3>
+                     <p>This chart shows how each layer in GPT-2 contributes to the log probability of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                     <div id="layer-chart"></div>
+
+                     <h3>Attention Head Contributions</h3>
+                     <p>This visualization shows how each attention head in GPT-2 contributes to the prediction of the token "{{ prediction.token }}" (as % of total contribution).</p>
+
+                     <div class="legend">
+                         <div class="legend-item">
+                             <div class="legend-color" style="background-color: #4575b4;"></div>
+                             <span>Negative contribution %</span>
+                         </div>
+                         <div class="legend-item">
+                             <div class="legend-color" style="background-color: #ffffbf;"></div>
+                             <span>Neutral (0%)</span>
+                         </div>
+                         <div class="legend-item">
+                             <div class="legend-color" style="background-color: #d73027;"></div>
+                             <span>Positive contribution %</span>
+                         </div>
+                     </div>
+
+                     <div id="visualization"></div>
+                 </div>
+             </div>
+         </div>
+         {% endif %}
+         {% endif %}
+     </div>
+
+     <script src="https://cdn.jsdelivr.net/npm/bootstrap@5/dist/js/bootstrap.bundle.min.js"></script>
+
+     {% if head_contributions %}
+     <script>
+         document.addEventListener('DOMContentLoaded', function() {
+             const headContributions = {{ head_contributions|safe }};
+
+             // Create layer contributions bar chart
+             const createLayerChart = () => {
+                 const layerContribs = headContributions.layer_contributions;
+                 const margin = { top: 40, right: 30, bottom: 50, left: 60 };
+                 const width = Math.min(800, window.innerWidth - 100);
+                 const height = 300;
+
+                 const svg = d3.select("#layer-chart")
+                     .append("svg")
+                     .attr("width", width)
+                     .attr("height", height);
+
+                 const g = svg.append("g")
+                     .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                 // Create scales
+                 const x = d3.scaleBand()
+                     .domain(d3.range(layerContribs.length))
+                     .range([0, width - margin.left - margin.right])
+                     .padding(0.1);
+
+                 const y = d3.scaleLinear()
+                     .domain([
+                         Math.min(0, d3.min(layerContribs)),
+                         Math.max(0, d3.max(layerContribs))
+                     ])
+                     .nice()
+                     .range([height - margin.top - margin.bottom, 0]);
+
+                 // Create color scale - positive is green, negative is purple
+                 const colorScale = d3.scaleLinear()
+                     .domain([Math.min(0, d3.min(layerContribs)), 0, Math.max(0, d3.max(layerContribs))])
+                     .range(["#9467bd", "#f7f7f7", "#2ca02c"]);
+
+                 // Create tooltip
+                 const tooltip = d3.select("body")
+                     .append("div")
+                     .attr("class", "tooltip")
+                     .style("opacity", 0);
+
+                 // Create bars; bind each value with its layer index so the
+                 // tooltip reports the correct layer even when two layers
+                 // have identical contribution values
+                 g.selectAll(".bar")
+                     .data(layerContribs.map((value, index) => ({ value, index })))
+                     .join("rect")
+                     .attr("class", "bar")
+                     .attr("x", d => x(d.index))
+                     .attr("y", d => d.value >= 0 ? y(d.value) : y(0))
+                     .attr("width", x.bandwidth())
+                     .attr("height", d => Math.abs(y(0) - y(d.value)))
+                     .attr("fill", d => colorScale(d.value))
+                     .attr("stroke", "#555")
+                     .attr("stroke-width", 1)
+                     .on("mouseover", function(event, d) {
+                         d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                         tooltip.transition().duration(200).style("opacity", 1);
+                         tooltip.html(`Layer ${d.index}<br>Contribution: ${d.value.toFixed(2)}%`)
+                             .style("left", (event.pageX + 10) + "px")
+                             .style("top", (event.pageY - 28) + "px");
+                     })
+                     .on("mouseout", function() {
+                         d3.select(this).attr("stroke", "#555").attr("stroke-width", 1);
+                         tooltip.transition().duration(500).style("opacity", 0);
+                     });
+
+                 // Add x-axis
+                 g.append("g")
+                     .attr("transform", `translate(0,${y(0)})`)
+                     .call(d3.axisBottom(x).tickFormat(i => `L${i}`))
+                     .selectAll("text")
+                     .style("font-size", "12px");
+
+                 // Add y-axis
+                 g.append("g")
+                     .call(d3.axisLeft(y).tickFormat(d => `${d.toFixed(1)}%`))
+                     .selectAll("text")
+                     .style("font-size", "12px");
+
+                 // Add title
+                 svg.append("text")
+                     .attr("x", width / 2)
+                     .attr("y", 20)
+                     .attr("text-anchor", "middle")
+                     .style("font-size", "16px")
+                     .style("font-weight", "bold")
+                     .text("Layer Contributions to Log Probability (%)");
+
+                 // Add x-axis label
+                 svg.append("text")
+                     .attr("x", width / 2)
+                     .attr("y", height - 10)
+                     .attr("text-anchor", "middle")
+                     .style("font-size", "14px")
+                     .text("Layer");
+
+                 // Add y-axis label
+                 svg.append("text")
+                     .attr("transform", "rotate(-90)")
+                     .attr("x", -(height / 2))
+                     .attr("y", 15)
+                     .attr("text-anchor", "middle")
+                     .style("font-size", "14px")
+                     .text("Contribution %");
+             };
+
+             // Create head contributions heatmap
+             const createHeadHeatmap = () => {
+                 // Define visualization parameters
+                 const cellSize = 40;
+                 const numLayers = headContributions.contributions.length;
+                 const numHeads = headContributions.contributions[0].length;
+                 const margin = { top: 60, right: 20, bottom: 20, left: 60 };
+                 const width = cellSize * numHeads + margin.left + margin.right;
+                 const height = cellSize * numLayers + margin.top + margin.bottom;
+
+                 // Create SVG
+                 const svg = d3.select("#visualization")
+                     .append("svg")
+                     .attr("width", width)
+                     .attr("height", height);
+
+                 // Create a group for the heatmap
+                 const g = svg.append("g")
+                     .attr("transform", `translate(${margin.left},${margin.top})`);
+
+                 // Create color scale: a symmetric domain around zero keeps 0% at
+                 // the neutral midpoint, and the reversed order maps positive to
+                 // red and negative to blue, matching the legend above
+                 const maxAbs = Math.max(Math.abs(headContributions.max_value), Math.abs(headContributions.min_value));
+                 const colorScale = d3.scaleSequential(d3.interpolateRdYlBu)
+                     .domain([maxAbs, -maxAbs]);
+
+                 // Create tooltip
+                 const tooltip = d3.select("body")
+                     .append("div")
+                     .attr("class", "tooltip")
+                     .style("opacity", 0);
+
+                 // Create cells
+                 for (let layer = 0; layer < numLayers; layer++) {
+                     for (let head = 0; head < numHeads; head++) {
+                         const contribution = headContributions.contributions[layer][head];
+
+                         g.append("rect")
+                             .attr("class", "head-cell")
+                             .attr("x", head * cellSize)
+                             .attr("y", layer * cellSize)
+                             .attr("width", cellSize)
+                             .attr("height", cellSize)
+                             .attr("fill", colorScale(contribution))
+                             .on("mouseover", function(event) {
+                                 d3.select(this).attr("stroke", "#000").attr("stroke-width", 2);
+                                 tooltip.transition().duration(200).style("opacity", 1);
+                                 tooltip.html(`Layer ${layer}, Head ${head}<br>Contribution: ${contribution.toFixed(2)}%`)
+                                     .style("left", (event.pageX + 10) + "px")
+                                     .style("top", (event.pageY - 28) + "px");
+                             })
+                             .on("mouseout", function() {
+                                 d3.select(this).attr("stroke", "#ddd").attr("stroke-width", 1);
+                                 tooltip.transition().duration(500).style("opacity", 0);
+                             });
+                     }
+                 }
+
+                 // Add layer labels
+                 for (let layer = 0; layer < numLayers; layer++) {
+                     g.append("text")
+                         .attr("class", "layer-label")
+                         .attr("x", -10)
+                         .attr("y", layer * cellSize + cellSize / 2)
+                         .attr("text-anchor", "end")
+                         .attr("dominant-baseline", "middle")
+                         .text(`L${layer}`);
+                 }
+
+                 // Add head labels
+                 for (let head = 0; head < numHeads; head++) {
+                     g.append("text")
+                         .attr("class", "head-label")
+                         .attr("x", head * cellSize + cellSize / 2)
+                         .attr("y", -10)
+                         .attr("text-anchor", "middle")
+                         .attr("dominant-baseline", "central")
+                         .text(`H${head}`);
+                 }
+
+                 // Add title
+                 svg.append("text")
+                     .attr("x", width / 2)
+                     .attr("y", 20)
+                     .attr("text-anchor", "middle")
+                     .style("font-size", "16px")
+                     .style("font-weight", "bold")
+                     .text("Head Contributions to Log Probability (%)");
+             };
+
+             // Create both visualizations
+             createLayerChart();
+             createHeadHeatmap();
+         });
+     </script>
+     {% endif %}
+ </body>
+ </html>