{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "-c8CAtScHu88" }, "source": [ "# [1.4.1] Indirect Object Identification (exercises)\n", "\n", "> **ARENA [Streamlit Page](https://arena-chapter1-transformer-interp.streamlit.app/21_📚_[1.4.1]_Indirect_Object_Identification)**\n", ">\n", "> **Colab: [exercises](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter1_transformer_interp/exercises/part41_indirect_object_identification/1.4.1_Indirect_Object_Identification_exercises.ipynb?t=20250305) | [solutions](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter1_transformer_interp/exercises/part41_indirect_object_identification/1.4.1_Indirect_Object_Identification_solutions.ipynb?t=20250305)**\n", "\n", "Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-uk/shared_invite/zt-2zick19fl-6GY1yoGaoUozyM3wObwmnQ), and ask any questions on the dedicated channels for this chapter of material.\n", "\n", "You can collapse each section so only the headers are visible, by clicking the arrow symbol on the left hand side of the markdown header cells.\n", "\n", "Links to all other chapters: [(0) Fundamentals](https://arena-chapter0-fundamentals.streamlit.app/), [(1) Transformer Interpretability](https://arena-chapter1-transformer-interp.streamlit.app/), [(2) RL](https://arena-chapter2-rl.streamlit.app/)." ] }, { "cell_type": "markdown", "metadata": { "id": "wZEmBjgXHu89" }, "source": [ "" ] }, { "cell_type": "markdown", "metadata": { "id": "vYkhTw4LHu8-" }, "source": [ "# Introduction" ] }, { "cell_type": "markdown", "metadata": { "id": "VXR4dZ94Hu8-" }, "source": [ "This notebook / document is built around the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper, in which the authors aim to understand the **indirect object identification circuit** in GPT-2 small. This circuit is resposible for the model's ability to complete sentences like `\"John and Mary went to the shops, John gave a bag to\"` with the correct token \"`\" Mary\"`.\n", "\n", "It is loosely divided into different sections, each one with their own flavour. Sections 1, 2 & 3 are derived from Neel Nanda's notebook [Exploratory_Analysis_Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh). The flavour of these exercises is experimental and loose, with a focus on demonstrating what exploratory analysis looks like in practice with the transformerlens library. They skimp on rigour, and instead try to speedrun the process of finding suggestive evidence for this circuit. The code and exercises are simple and generic, but accompanied with a lot of detail about what each stage is doing, and why (plus several optional details and tangents). Section 4 introduces you to the idea of **path patching**, which is a more rigorous and structured way of analysing the model's behaviour. Here, you'll be replicating some of the results of the paper, which will serve to rigorously validate the insights gained from earlier sections. It's the most technically dense of all five sections. Lastly, sections 5 & 6 are much less structured, and have a stronger focus on open-ended exercises & letting you go off and explore for yourself.\n", "\n", "Which exercises you want to do will depend on what you're hoping to get out of these exercises. 
For example:\n", "\n", "* You want to understand activation patching - **1, 2, 3**\n", "* You want to get a sense of how to do exploratory analysis on a model - **1, 2, 3**\n", "* You want to understand activation and path patching - **1, 2, 3, 4**\n", "* You want to understand the IOI circuit fully, and replicate the paper's key results - **1, 2, 3, 4, 5**\n", "* You want to understand the IOI circuit fully, and replicate the paper's key results (but you already understand activation patching) - **1, 2, 4, 5**\n", "* You want to understand IOI, and then dive deeper e.g. by looking for more circuits in models or investigating anomalies - **1, 2, 3, 4, 5, 6**\n", "\n", "*Note - if you find yourself getting frequent CUDA memory errors, you can periodically call `torch.cuda.empty_cache()` to [free up some memory](https://stackoverflow.com/questions/57858433/how-to-clear-gpu-memory-after-pytorch-model-training-without-restarting-kernel).*\n", "\n", "Each exercise will have a difficulty and importance rating out of 5, as well as an estimated maximum time you should spend on these exercises and sometimes a short annotation. You should interpret the ratings & time estimates relatively (e.g. if you find yourself spending about 50% longer on the exercises than the time estimates, adjust accordingly). Please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!" ] }, { "cell_type": "markdown", "metadata": { "id": "yaOlw6n6Hu8-" }, "source": [ "## The purpose / structure of these exercises\n", "\n", "At a surface level, these exercises are designed to take you through the indirect object identification circuit. But it's also designed to make you a better interpretability researcher! As a result, most exercises will be doing a combination of:\n", "\n", "1. Showing you some new feature/component of the circuit, and\n", "2. Teaching you how to use tools and interpret results in a broader mech interp context.\n", "\n", "A key idea to have in mind during these exercises is the **spectrum from simpler, more exploratory tools to more rigoruous, complex tools**. On the simpler side, you have something like inspecting attention patterns, which can give a decent (but sometimes misleading) picture of what an attention head is doing. These should be some of the first tools you reach for, and you should be using them a lot even before you have concrete hypotheses about a circuit. On the more rigorous side, you have something like path patching, which is a pretty rigorous and effortful tool that is best used when you already have reasonably concrete hypotheses about a circuit. As we go through the exercises, we'll transition from left to right along this spectrum." ] }, { "cell_type": "markdown", "metadata": { "id": "sK1K_aVVHu8-" }, "source": [ "## The IOI task\n", "\n", "The first step when trying to reverse engineer a circuit in a model is to identify *what* capability we want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [Neel Nanda's interview with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). 
The task is to complete sentences like \"When Mary and John went to the store, John gave a drink to\" with \" Mary\" rather than \" John\".\n", "\n", "In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads used to perform this capability. The circuit they found roughly breaks down into three parts:\n", "\n", "1. Identify what names are in the sentence\n", "2. Identify which names are duplicated\n", "3. Predict the name that is *not* duplicated\n", "\n", "Why was this task chosen? The authors give a very good explanation for their choice in their [video walkthrough of their paper](https://www.youtube.com/watch?v=gzwj0jWbvbo), which you are encouraged to watch. To be brief, some of the reasons were:\n", "\n", "* This is a fairly common grammatical structure, so we should expect the model to build some circuitry for solving it quite early on (after it's finished with all the more basic stuff, like n-grams, punctuation, induction, and simpler grammatical structures than this one).\n", "* It's easy to measure: the model always puts a much higher probability on the IO and S tokens (i.e. `\" Mary\"` and `\" John\"`) than any others, and this is especially true once the model starts being stripped down to the core part of the circuit we're studying. So we can just take the logit difference between these two tokens, and use this as a metric for how well the model can solve the task.\n", "* It is a crisp and well-defined task, so less likely to be solved in terms of memorisation of a large bag of heuristics (unlike e.g. tasks like \"predict that the number `n+1` will follow `n`, which as Neel mentions in the video walkthrough is actually much more annoying and subtle than it first seems!).\n", "\n", "A terminology note: `IO` will refer to the indirect object (in the example, `\" Mary\"`), `S1` and `S2` will refer to the two instances of the subject token (i.e. `\" John\"`), and `end` will refer to the end token `\" to\"` (because this is the position we take our prediction from, and we don't care about any tokens after this point). We will also sometimes use `S` to refer to the identity of the subject token (rather than referring to the first or second instance in particular)." ] }, { "cell_type": "markdown", "metadata": { "id": "kpk_nMt0Hu8-" }, "source": [ "## Keeping track of your guesses & predictions\n", "\n", "There's a lot to keep track of in these exercises as we work through them. You'll be exposed to new functions and modules from transformerlens, new ways to causally intervene in models, all the while building up your understanding of how the IOI task is performed. The notebook starts off exploratory in nature (lots of plotting and investigation), and gradually moves into more technical details, refined analysis, and replication of the paper's results, as we improve our understanding of the IOI circuit. You are recommended to keep a document or page of notes nearby as you go through these exercises, so you can keep track of the main takeaways from each section, as well as your hypotheses for how the model performs the task, and your ideas for how you might go off and test these hypotheses on your own if the notebook were to suddenly end.\n", "\n", "If you are feeling extremely confused at any point, you can come back to the dropdown below, which contains diagrams explaining how the circuit works. There is also an accompanying intuitive explanation which you might find more helpful. 
However, I'd recommend you try and go through the notebook unassisted before looking at these.\n", "\n", "
\n", "Intuitive explanation of IOI circuit\n", "\n", "First, let's start with an analogy for how transformers work (you can skip this if you've already read [my post](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)). Imagine a line of people, who can only look forward. Each person has a token written on their chest, and their goal is to figure out what token the person in front of them is holding. Each person is allowed to pass a question backwards along the line (not forwards), and anyone can choose to reply to that question by passing information forwards to the person who asked. In this case, the sentence is `\"When Mary and John went to the store, John gave a drink to Mary\"`. You are the person holding the `\" to\"` token, and your goal is to figure out that the person in front of him has the `\" Mary\"` token.\n", "\n", "To be clear about how this analogy relates to transformers:\n", "* Each person in the line represents a vector in the residual stream. Initially they just store their own token, but they accrue more information as they ask questions and receive answers (i.e. as components write to the residual stream)\n", "* The operation of an attention head is represented by a question & answer:\n", " * The person who asks is the destination token, the people who answer are the source tokens\n", " * The question is the query vector\n", " * The information *which determines who answers the question* is the key vector\n", " * The information *which gets passed back to the original asker* is the value vector\n", "\n", "Now, here is how the IOI circuit works in this analogy. Each bullet point represents a class of attention heads.\n", "\n", "* The person with the second `\" John\"` token asks the question \"does anyone else hold the name `\" John\"`?\". They get a reply from the first `\" John\"` token, who also gives him their location. So he now knows that `\" John\"` is repeated, and he knows that the first `\" John\"` token is 4th in the sequence.\n", " * These are *Duplicate Token Heads*\n", "* You ask the question \"which names are repeated?\", and you get an answer from the person holding the second `\" John\"` token. You now also know that `\" John\"` is repeated, and where the first `\" John\"` token is.\n", " * These are *S-Inhibition Heads*\n", "* You ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\". You get a reply from the person holding the `\" Mary\"` token, who tells you that they have name `\" Mary\"`. You use this as your prediction.\n", " * These are *Name Mover Heads*\n", "\n", "This is a fine first-pass understanding of how the circuit works. A few other features:\n", "\n", "* The person after the first `\" John\"` (holding `\" went\"`) had previously asked about the identity of the person behind him. So he knows that the 4th person in the sequence holds the `\" John\"` token, meaning he can also reply to the question of the person holding the second `\" John\"` token. 
*(previous token heads / induction heads)*\n", " * This might not seem necessary, but since previous token heads / induction heads are just a pretty useful thing to have in general, it makes sense that you'd want to make use of this information!\n", "* If for some reason you forget to ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\", then you'll have another chance to do this.\n", " * These are *Backup Name Mover Heads*\n", " * Their existence might be partly because transformers are trained with **dropout**. This can make them \"forget\" things, so it's important to have a backup method for recovering that information!\n", "* You want to avoid overconfidence, so you also ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\" another time, in order to ***anti-***predict the response that you get from this question. *(negative name mover heads)*\n", " * Yes, this is as weird as it sounds! The authors speculate that these heads \"hedge\" the predictions, avoiding high cross-entropy loss when making mistakes.\n", "\n", "
\n", "\n", "
\n", "Diagram 1 (simple)\n", "\n", "\n", "\n", "
\n", "\n", "
\n", "Diagram 2 (complex)\n", "\n", "\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "MJWLgA9VHu8-" }, "source": [ "## Content & Learning Objectives\n", "\n", "### 1️⃣ Model & Task Setup\n", "\n", "In this section you'll set up your model, and see how to analyse its performance on the IOI task. You'll also learn how to measure it's performance using tools like logit difference.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the IOI task, and why the authors chose to study it\n", "> * Build functions to demonstrate the model's performance on this task\n", "\n", "### 2️⃣ Logit Attribution\n", "\n", "Next, you'll move on to some component attribution: evaluating the importance of each model component for the IOI task. However, this type of analysis is limited to measuring a component's direct effect, as opposed to indirect effect - we'll measure the latter in future sections.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Perform direct logit attribution to figure out which heads are writing to the residual stream in a significant way\n", "> * Learn how to use different transformerlens helper functions, which decompose the residual stream in different ways\n", "\n", "### 3️⃣ Activation Patching\n", "\n", "We introduce one of the two important patching tools you'll use during this section: **activation patching**. This can be used to discover which components of a model are important for a particular task, by measuring the changes in our previously-defined task metrics when you patch into a particular component with corrupted input.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the idea of activation patching, and how it can be used\n", "> * Implement some of the activation patching helper functinos in transformerlens from scratch (i.e. using hooks)\n", "> * Use activation patching to track the layers & sequence positions in the residual stream where important information is stored and processed\n", "> * By the end of this section, you should be able to draw a rough sketch of the IOI circuit\n", "\n", "### 4️⃣ Path Patching\n", "\n", "Next, we move to path patching, a more refined form of activation patching which examines the importance of particular paths between model components. This will give us a more precise picture of how our circuit works.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the idea of path patching, and how it differs from activation patching\n", "> * Implement path patching from scratch (i.e. using hooks)\n", "> * Replicate several of the results in the [IOI paper](https://arxiv.org/abs/2211.00593)\n", "\n", "### 5️⃣ Full Replication: Minimial Circuits and more\n", "\n", "Lastly, we'll do some cleaning up, by replicating some other results from the IOI paper. This includes implementing a complex form of ablation which removes every component from the model except for the ones we've identified from previous analysis, and showing that the performace is recovered. This section is more open-ended and less structured.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Replicate most of the other results from the [IOI paper](https://arxiv.org/abs/2211.00593)\n", "> * Practice more open-ended, less guided coding\n", "\n", "### ☆ Bonus / exploring anomalies\n", "\n", "We end with a few suggested bonus exercises for this particular circuit, as well as ideas for capstone projects / paper replications.\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Explore other parts of the model (e.g. 
negative name mover heads, and induction heads)\n", "> * Understand the subtleties present in model circuits, and the fact that there are often more parts to a circuit than seem obvious after initial investigation\n", "> * Understand the importance of the three quantitative criteria used by the paper: **faithfulness**, **completeness** and **minimality**" ] }, { "cell_type": "markdown", "metadata": { "id": "srwWIlEUHu8_" }, "source": [ "## Setup code" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "1-xATUaWHu8_" }, "outputs": [], "source": [ "import os\n", "import sys\n", "from pathlib import Path\n", "\n", "IN_COLAB = \"google.colab\" in sys.modules\n", "\n", "chapter = \"chapter1_transformer_interp\"\n", "repo = \"ARENA_3.0\"\n", "branch = \"main\"\n", "\n", "# Install dependencies\n", "try:\n", " import transformer_lens\n", "except:\n", " %pip install transformer_lens==2.11.0 einops jaxtyping git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python\n", "\n", "# Get root directory, handling 3 different cases: (1) Colab, (2) notebook not in ARENA repo, (3) notebook in ARENA repo\n", "root = (\n", " \"/content\"\n", " if IN_COLAB\n", " else \"/root\"\n", " if repo not in os.getcwd()\n", " else str(next(p for p in Path.cwd().parents if p.name == repo))\n", ")\n", "\n", "if Path(root).exists() and not Path(f\"{root}/{chapter}\").exists():\n", " if not IN_COLAB:\n", " !sudo apt-get install unzip\n", " %pip install jupyter ipython --upgrade\n", "\n", " if not os.path.exists(f\"{root}/{chapter}\"):\n", " !wget -P {root} https://github.com/callummcdougall/ARENA_3.0/archive/refs/heads/{branch}.zip\n", " !unzip {root}/{branch}.zip '{repo}-{branch}/{chapter}/exercises/*' -d {root}\n", " !mv {root}/{repo}-{branch}/{chapter} {root}/{chapter}\n", " !rm {root}/{branch}.zip\n", " !rmdir {root}/{repo}-{branch}\n", "\n", "\n", "if f\"{root}/{chapter}/exercises\" not in sys.path:\n", " sys.path.append(f\"{root}/{chapter}/exercises\")\n", "\n", "os.chdir(f\"{root}/{chapter}/exercises\")" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "0dRnIL57Hu8_" }, "outputs": [], "source": [ "import re\n", "import sys\n", "from functools import partial\n", "from itertools import product\n", "from pathlib import Path\n", "from typing import Callable, Literal\n", "\n", "import circuitsvis as cv\n", "import einops\n", "import numpy as np\n", "import plotly.express as px\n", "import torch as t\n", "from IPython.display import HTML, display\n", "from jaxtyping import Bool, Float, Int\n", "from rich import print as rprint\n", "from rich.table import Column, Table\n", "from torch import Tensor\n", "from tqdm.notebook import tqdm\n", "from transformer_lens import ActivationCache, HookedTransformer, utils\n", "from transformer_lens.components import MLP, Embed, LayerNorm, Unembed\n", "from transformer_lens.hook_points import HookPoint\n", "\n", "t.set_grad_enabled(False)\n", "device = t.device(\"mps\" if t.backends.mps.is_available() else \"cuda\" if t.cuda.is_available() else \"cpu\")\n", "\n", "# Make sure exercises are in the path\n", "chapter = \"chapter1_transformer_interp\"\n", "section = \"part41_indirect_object_identification\"\n", "root_dir = next(p for p in Path.cwd().parents if (p / chapter).exists())\n", "exercises_dir = root_dir / chapter / \"exercises\"\n", "section_dir = exercises_dir / section\n", "\n", "import part41_indirect_object_identification.tests as tests\n", "from plotly_utils import bar, imshow, line, scatter\n", "\n", "MAIN = __name__ == 
\"__main__\"" ] }, { "cell_type": "markdown", "metadata": { "id": "t79_GYU9Hu8_" }, "source": [ "# 1️⃣ Model & Task Setup\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the IOI task, and why the authors chose to study it\n", "> * Build functions to demonstrate the model's performance on this task" ] }, { "cell_type": "markdown", "metadata": { "id": "zTcYe8q8Hu8_" }, "source": [ "## Loading our model" ] }, { "cell_type": "markdown", "metadata": { "id": "UVXbdZb-Hu8_" }, "source": [ "The first step is to load in our model, GPT-2 Small, a 12 layer and 80M parameter transformer with `HookedTransformer.from_pretrained`. The various flags are simplifications that preserve the model's output but simplify its internals." ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "id": "wvM3s-btHu9A", "outputId": "c2a30245-a188-4530-e148-c41ae704d1c2", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Loaded pretrained model gpt2-small into HookedTransformer\n" ] } ], "source": [ "model = HookedTransformer.from_pretrained(\n", " \"gpt2-small\",\n", " center_unembed=True,\n", " center_writing_weights=True,\n", " fold_ln=True,\n", " refactor_factored_attn_matrices=True,\n", ")" ] }, { "cell_type": "code", "source": [ "# Show column norms are the same (except first few, for fiddly bias reasons)\n", "line([model.W_Q[0, 0].pow(2).sum(0), model.W_K[0, 0].pow(2).sum(0)])\n", "# Show columns are orthogonal (except first few, again)\n", "W_Q_dot_products = einops.einsum(\n", " model.W_Q[0, 0], model.W_Q[0, 0], \"d_model d_head_1, d_model d_head_2 -> d_head_1 d_head_2\"\n", ")\n", "imshow(W_Q_dot_products)" ], "metadata": { "id": "1MhABQ_SN5nL", "outputId": "3d43b14c-0e69-4964-9349-a46e3c294135", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 } }, "execution_count": 4, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "mtYHhtRuHu9A" }, "source": [ "
\n", "Note on refactor_factored_attn_matrices (optional)\n", "\n", "This argument means we redefine the matrices $W_Q$, $W_K$, $W_V$ and $W_O$ in the model (without changing the model's actual behaviour).\n", "\n", "For example, we know that instead of working with $W_Q$ and $W_K$ individually, the only matrix we actually need to use in the model is the low-rank matrix $W_Q W_K^T$ (note that I'm using the convention of matrix multiplication on the right, which matches the code in transformerlens and previous exercises in this series, but doesn't match Anthropic's Mathematical Frameworks paper). So if we perform singular value decomposition $W_Q W_K^T = U S V^T$, then we see that we can just as easily define $W_Q = U \\sqrt{S}$ and $W_K = V \\sqrt{S}$ and use these instead. This means that $W_Q$ and $W_K$ both have orthogonal columns with matching norms. You can investigate this yourself (e.g. using the code below). This is arguably a more interpretable setup, because now there's no obvious asymmetry between the keys and queries.\n", "\n", "There's also some fiddlyness with how biases are handled in this factorisation, which is why the comments above don't hold absolutely (see the documentation for more info).\n", "\n", "```python\n", "# Show column norms are the same (except first few, for fiddly bias reasons)\n", "line([model.W_Q[0, 0].pow(2).sum(0), model.W_K[0, 0].pow(2).sum(0)])\n", "# Show columns are orthogonal (except first few, again)\n", "W_Q_dot_products = einops.einsum(\n", " model.W_Q[0, 0], model.W_Q[0, 0], \"d_model d_head_1, d_model d_head_2 -> d_head_1 d_head_2\"\n", ")\n", "imshow(W_Q_dot_products)\n", "```\n", "\n", "In a similar way, since $W_{OV} = W_V W_O = U S V^T$, we can define $W_V = U S$ and $W_O = V^T$. This is arguably a more interpretable setup, because now $W_O$ is just a rotation, and doesn't change the norm, so $z$ has the same norm as the result of the head.\n", "
\n", "\n", "
\n", "Note on fold_ln, center_unembed and center_writing_weights (optional)\n", "\n", "See link [here](https://github.com/neelnanda-io/TransformerLens/blob/main/further_comments.md#what-is-layernorm-folding-fold_ln) for comments.\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "B3HoxhbyHu9A" }, "source": [ "The next step is to verify that the model can *actually* do the task! Here we use `utils.test_prompt`, and see that the model is significantly better at predicting Mary than John!\n", "\n", "
Asides\n", "\n", "Note: If we were being careful, we'd want to run the model on a range of prompts and find the average performance. We'll do more stuff like this in the fourth section (when we try to replicate some of the paper's results, and take a more rigorous approach).\n", "\n", "`prepend_bos` is a flag to add a BOS (beginning of sequence) to the start of the prompt. GPT-2 was not trained with this, but I find that it often makes model behaviour more stable, as the first token is treated weirdly.\n", "
" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "id": "M-FkXRaSHu9A", "outputId": "d58ac02b-2fce-4650-d7c7-ac4cdcd55e72", "colab": { "base_uri": "https://localhost:8080/", "height": 274 } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']\n", "Tokenized answer: [' Mary']\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "Performance on answer token:\n", "\u001b[1mRank: \u001b[0m\u001b[1;36m0\u001b[0m\u001b[1m Logit: \u001b[0m\u001b[1;36m18.09\u001b[0m\u001b[1m Prob: \u001b[0m\u001b[1;36m70.07\u001b[0m\u001b[1m% Token: | Mary|\u001b[0m\n" ], "text/html": [ "
Performance on answer token:\n",
              "Rank: 0        Logit: 18.09 Prob: 70.07% Token: | Mary|\n",
              "
\n" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" ], "text/html": [ "
Ranks of the answer tokens: [(' Mary', 0)]\n",
              "
\n" ] }, "metadata": {} } ], "source": [ "# Here is where we test on a single prompt\n", "# Result: 70% probability on Mary, as we expect\n", "\n", "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", "example_answer = \" Mary\"\n", "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "_jNBtLU-Hu9A" }, "source": [ "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises. In section 3, we'll work with a dataset similar to the one used by the paper authors, but this probably wouldn't be the first thing we reached for if we were just doing initial investigations.\n", "\n", "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", "\n", "
Aside on tokenization\n", "\n", "We want models that can take in arbitrary text, but models need to have a fixed vocabulary. So the solution is to define a vocabulary of **tokens** and to deterministically break up arbitrary text into tokens. Tokens are, essentially, subwords, and are determined by finding the most frequent substrings - this means that tokens vary a lot in length and frequency!\n", "\n", "Tokens are a *massive* headache and are one of the most annoying things about reverse engineering language models... Different names will be different numbers of tokens, different prompts will have the relevant tokens at different positions, different prompts will have different total numbers of tokens, etc. Language models often devote significant amounts of parameters in early layers to convert inputs from tokens to a more sensible internal format (and do the reverse in later layers). You really, really want to avoid needing to think about tokenization wherever possible when doing exploratory analysis (though, of course, it's relevant later when trying to flesh out your analysis and make it rigorous!). HookedTransformer comes with several helper methods to deal with tokens: `to_tokens, to_string, to_str_tokens, to_single_token, get_token_position`\n", "\n", "**Exercise:** I recommend using `model.to_str_tokens` to explore how the model tokenizes different strings. In particular, try adding or removing spaces at the start, or changing capitalization - these change tokenization!
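For instance, here's a minimal cell along these lines - the specific strings are just suggestions, and the exact splits may not be what you'd guess:

```python
# Leading spaces and capitalisation both change how a string gets tokenized
print(model.to_str_tokens("Mary"))
print(model.to_str_tokens(" Mary"))
print(model.to_str_tokens(" mary"))
print(model.to_str_tokens("When John and Mary went to the shops"))

# Raises an error if the string is not exactly one token
print(model.to_single_token(" Mary"))
```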
" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "mDZhT5B_Hu9A", "outputId": "94e0e835-1a1c-4b76-a1c2-e2185c4d6794", "colab": { "base_uri": "https://localhost:8080/", "height": 678 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m[\u001b[0m\n", " \u001b[32m'When John and Mary went to the shops, John gave the bag to'\u001b[0m,\n", " \u001b[32m'When John and Mary went to the shops, Mary gave the bag to'\u001b[0m,\n", " \u001b[32m'When Tom and James went to the park, James gave the ball to'\u001b[0m,\n", " \u001b[32m'When Tom and James went to the park, Tom gave the ball to'\u001b[0m,\n", " \u001b[32m'When Dan and Sid went to the shops, Sid gave an apple to'\u001b[0m,\n", " \u001b[32m'When Dan and Sid went to the shops, Dan gave an apple to'\u001b[0m,\n", " \u001b[32m'After Martin and Amy went to the park, Amy gave a drink to'\u001b[0m,\n", " \u001b[32m'After Martin and Amy went to the park, Martin gave a drink to'\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ], "text/html": [ "
[\n",
              "    'When John and Mary went to the shops, John gave the bag to',\n",
              "    'When John and Mary went to the shops, Mary gave the bag to',\n",
              "    'When Tom and James went to the park, James gave the ball to',\n",
              "    'When Tom and James went to the park, Tom gave the ball to',\n",
              "    'When Dan and Sid went to the shops, Sid gave an apple to',\n",
              "    'When Dan and Sid went to the shops, Dan gave an apple to',\n",
              "    'After Martin and Amy went to the park, Amy gave a drink to',\n",
              "    'After Martin and Amy went to the park, Martin gave a drink to'\n",
              "]\n",
              "
\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m[\u001b[0m\n", " \u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[32m' John'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' John'\u001b[0m, \u001b[32m' Mary'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Tom'\u001b[0m, \u001b[32m' James'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' James'\u001b[0m, \u001b[32m' Tom'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Dan'\u001b[0m, \u001b[32m' Sid'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Sid'\u001b[0m, \u001b[32m' Dan'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Martin'\u001b[0m, \u001b[32m' Amy'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Amy'\u001b[0m, \u001b[32m' Martin'\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ], "text/html": [ "
[\n",
              "    (' Mary', ' John'),\n",
              "    (' John', ' Mary'),\n",
              "    (' Tom', ' James'),\n",
              "    (' James', ' Tom'),\n",
              "    (' Dan', ' Sid'),\n",
              "    (' Sid', ' Dan'),\n",
              "    (' Martin', ' Amy'),\n",
              "    (' Amy', ' Martin')\n",
              "]\n",
              "
\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m5335\u001b[0m, \u001b[1;36m1757\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m1757\u001b[0m, \u001b[1;36m5335\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m4186\u001b[0m, \u001b[1;36m3700\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m3700\u001b[0m, \u001b[1;36m4186\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m6035\u001b[0m, \u001b[1;36m15686\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m15686\u001b[0m, \u001b[1;36m6035\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m5780\u001b[0m, \u001b[1;36m14235\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m14235\u001b[0m, \u001b[1;36m5780\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" ], "text/html": [ "
tensor([[ 5335,  1757],\n",
              "        [ 1757,  5335],\n",
              "        [ 4186,  3700],\n",
              "        [ 3700,  4186],\n",
              "        [ 6035, 15686],\n",
              "        [15686,  6035],\n",
              "        [ 5780, 14235],\n",
              "        [14235,  5780]])\n",
              "
\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[3m Prompts & Answers: \u001b[0m\n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mPrompt \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mCorrect \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mIncorrect\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩\n", "│ When John and Mary went to the shops, John gave the bag to │ ' Mary' │ ' John' │\n", "│ When John and Mary went to the shops, Mary gave the bag to │ ' John' │ ' Mary' │\n", "│ When Tom and James went to the park, James gave the ball to │ ' Tom' │ ' James' │\n", "│ When Tom and James went to the park, Tom gave the ball to │ ' James' │ ' Tom' │\n", "│ When Dan and Sid went to the shops, Sid gave an apple to │ ' Dan' │ ' Sid' │\n", "│ When Dan and Sid went to the shops, Dan gave an apple to │ ' Sid' │ ' Dan' │\n", "│ After Martin and Amy went to the park, Amy gave a drink to │ ' Martin' │ ' Amy' │\n", "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy' │ ' Martin' │\n", "└───────────────────────────────────────────────────────────────┴───────────┴───────────┘\n" ], "text/html": [ "
                                   Prompts & Answers:                                    \n",
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓\n",
              "┃ Prompt                                                         Correct    Incorrect ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩\n",
              "│ When John and Mary went to the shops, John gave the bag to    │ ' Mary'   │ ' John'   │\n",
              "│ When John and Mary went to the shops, Mary gave the bag to    │ ' John'   │ ' Mary'   │\n",
              "│ When Tom and James went to the park, James gave the ball to   │ ' Tom'    │ ' James'  │\n",
              "│ When Tom and James went to the park, Tom gave the ball to     │ ' James'  │ ' Tom'    │\n",
              "│ When Dan and Sid went to the shops, Sid gave an apple to      │ ' Dan'    │ ' Sid'    │\n",
              "│ When Dan and Sid went to the shops, Dan gave an apple to      │ ' Sid'    │ ' Dan'    │\n",
              "│ After Martin and Amy went to the park, Amy gave a drink to    │ ' Martin' │ ' Amy'    │\n",
              "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy'    │ ' Martin' │\n",
              "└───────────────────────────────────────────────────────────────┴───────────┴───────────┘\n",
              "
\n" ] }, "metadata": {} } ], "source": [ "prompt_format = [\n", " \"When John and Mary went to the shops,{} gave the bag to\",\n", " \"When Tom and James went to the park,{} gave the ball to\",\n", " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", " \"After Martin and Amy went to the park,{} gave a drink to\",\n", "]\n", "name_pairs = [\n", " (\" Mary\", \" John\"),\n", " (\" Tom\", \" James\"),\n", " (\" Dan\", \" Sid\"),\n", " (\" Martin\", \" Amy\"),\n", "]\n", "\n", "# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)\n", "prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]\n", "# Define the answers for each prompt, in the form (correct, incorrect)\n", "answers = [names[::i] for names in name_pairs for i in (1, -1)]\n", "# Define the answer tokens (same shape as the answers)\n", "answer_tokens = t.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])\n", "\n", "rprint(prompts)\n", "rprint(answers)\n", "rprint(answer_tokens)\n", "\n", "table = Table(\"Prompt\", \"Correct\", \"Incorrect\", title=\"Prompts & Answers:\")\n", "\n", "for prompt, answer in zip(prompts, answers):\n", " table.add_row(prompt, repr(answer[0]), repr(answer[1]))\n", "\n", "rprint(table)" ] }, { "cell_type": "markdown", "metadata": { "id": "cqw_akj8Hu9A" }, "source": [ "
\n", "Aside - the rich library\n", "\n", "The outputs above were created by `rich`, a fun library which prints things in nice formats. It has functions like `rich.table.Table`, which are very easy to use but can produce visually clear outputs which are sometimes useful.\n", "\n", "You can also color the columns of a table, by using the `rich.table.Column` argument with the `style` parameter:\n", "\n", "```python\n", "cols = [\n", " \"Prompt\",\n", " Column(\"Correct\", style=\"rgb(0,200,0) bold\"),\n", " Column(\"Incorrect\", style=\"rgb(255,0,0) bold\"),\n", "]\n", "table = Table(*cols, title=\"Prompts & Answers:\")\n", "\n", "for prompt, answer in zip(prompts, answers):\n", " table.add_row(prompt, repr(answer[0]), repr(answer[1]))\n", "\n", "rprint(table)\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "P6K-ikO1Hu9A" }, "source": [ "We now run the model on these prompts and use `run_with_cache` to get both the logits and a cache of all internal activations for later analysis." ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "id": "8xTESScHHu9A", "outputId": "18cdcfda-e007-4c6b-c5f5-75456dd2b74f", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([8, 15, 50257])\n" ] } ], "source": [ "tokens = model.to_tokens(prompts, prepend_bos=True)\n", "# Move the tokens to the GPU\n", "tokens = tokens.to(device)\n", "# Run the model and cache all activations\n", "original_logits, cache = model.run_with_cache(tokens)\n", "print(original_logits.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "EPYgDpsDHu9A" }, "source": [ "We'll later be evaluating how model performance differs upon performing various interventions, so it's useful to have a metric to measure model performance. Our metric here will be the **logit difference**, the difference in logit between the indirect object's name and the subject's name (eg, `logit(Mary) - logit(John)`)." ] }, { "cell_type": "markdown", "metadata": { "id": "MiIZiQzsHu9A" }, "source": [ "### Exercise - implement the performance evaluation function\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴⚪⚪\n", "> Importance: 🔵🔵🔵🔵⚪\n", ">\n", "> You should spend up to 10-15 minutes on this exercise.\n", "> It's important to understand exactly what this function is computing, and why it matters.\n", "> ```\n", "\n", "This function should take in your model's logit output (shape `(batch, seq, d_vocab)`), and the array of answer tokens (shape `(batch, 2)`, containing the token ids of correct and incorrect answers respectively for each sequence), and return the logit difference as described above. If `per_prompt` is False, then it should take the mean over the batch dimension, if not then it should return an array of length `batch`." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": { "id": "jpeLz0FpHu9A", "outputId": "8316d8f3-8956-48a2-d745-0ae1948026e5", "colab": { "base_uri": "https://localhost:8080/", "height": 279 } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "All tests in `test_logits_to_ave_logit_diff` passed!\n", "Per prompt logit difference: tensor([3.3367, 3.2016, 2.7095, 3.7974, 1.7204, 5.2812, 2.6008, 5.7674])\n", "Average logit difference: tensor(3.5519)\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[3m Logit differences \u001b[0m\n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mPrompt \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mCorrect \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mIncorrect\u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mLogit Difference\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩\n", "│ When John and Mary went to the shops, John gave the bag to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' Mary' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' John' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3.337 \u001b[0m\u001b[1m \u001b[0m│\n", "│ When John and Mary went to the shops, Mary gave the bag to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' John' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Mary' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3.202 \u001b[0m\u001b[1m \u001b[0m│\n", "│ When Tom and James went to the park, James gave the ball to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' Tom' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' James' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m2.709 \u001b[0m\u001b[1m \u001b[0m│\n", "│ When Tom and James went to the park, Tom gave the ball to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' James' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Tom' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m3.797 \u001b[0m\u001b[1m \u001b[0m│\n", "│ When Dan and Sid went to the shops, Sid gave an apple to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' Dan' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Sid' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m1.720 \u001b[0m\u001b[1m \u001b[0m│\n", "│ When Dan and Sid went to the shops, Dan gave an apple to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' Sid' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Dan' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m5.281 \u001b[0m\u001b[1m \u001b[0m│\n", "│ After Martin and Amy went to the park, Amy gave a drink to │\u001b[1;38;2;0;200;0m \u001b[0m\u001b[1;38;2;0;200;0m' Martin'\u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Amy' \u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m2.601 \u001b[0m\u001b[1m \u001b[0m│\n", "│ After Martin and Amy went to the park, Martin gave a drink to │\u001b[1;38;2;0;200;0m 
\u001b[0m\u001b[1;38;2;0;200;0m' Amy' \u001b[0m\u001b[1;38;2;0;200;0m \u001b[0m│\u001b[1;38;2;255;0;0m \u001b[0m\u001b[1;38;2;255;0;0m' Martin'\u001b[0m\u001b[1;38;2;255;0;0m \u001b[0m│\u001b[1m \u001b[0m\u001b[1m5.767 \u001b[0m\u001b[1m \u001b[0m│\n", "└───────────────────────────────────────────────────────────────┴───────────┴───────────┴──────────────────┘\n" ], "text/html": [ "
                                             Logit differences                                              \n",
              "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓\n",
              "┃ Prompt                                                         Correct    Incorrect  Logit Difference ┃\n",
              "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩\n",
              "│ When John and Mary went to the shops, John gave the bag to    │ ' Mary'    ' John'    3.337            │\n",
              "│ When John and Mary went to the shops, Mary gave the bag to    │ ' John'    ' Mary'    3.202            │\n",
              "│ When Tom and James went to the park, James gave the ball to   │ ' Tom'     ' James'   2.709            │\n",
              "│ When Tom and James went to the park, Tom gave the ball to     │ ' James'   ' Tom'     3.797            │\n",
              "│ When Dan and Sid went to the shops, Sid gave an apple to      │ ' Dan'     ' Sid'     1.720            │\n",
              "│ When Dan and Sid went to the shops, Dan gave an apple to      │ ' Sid'     ' Dan'     5.281            │\n",
              "│ After Martin and Amy went to the park, Amy gave a drink to    │ ' Martin'  ' Amy'     2.601            │\n",
              "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy'     ' Martin'  5.767            │\n",
              "└───────────────────────────────────────────────────────────────┴───────────┴───────────┴──────────────────┘\n",
              "
\n" ] }, "metadata": {} } ], "source": [ "def logits_to_ave_logit_diff(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", " per_prompt: bool = False,\n", ") -> Float[Tensor, \"*batch\"]:\n", " \"\"\"\n", " Returns logit difference between the correct and incorrect answer.\n", "\n", " If per_prompt=True, return the array of differences rather than the average.\n", " \"\"\"\n", " batch_idx = t.arange(logits.size(0))\n", "\n", " correct = logits[:, -1, :][batch_idx, answer_tokens[:, 0]]\n", " incorrect = logits[:, -1, :][batch_idx, answer_tokens[:, 1]]\n", " if per_prompt: return correct - incorrect\n", " return t.mean(correct-incorrect)\n", "\n", "\n", "tests.test_logits_to_ave_logit_diff(logits_to_ave_logit_diff)\n", "\n", "original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", "print(\"Per prompt logit difference:\", original_per_prompt_diff)\n", "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", "print(\"Average logit difference:\", original_average_logit_diff)\n", "\n", "cols = [\n", " \"Prompt\",\n", " Column(\"Correct\", style=\"rgb(0,200,0) bold\"),\n", " Column(\"Incorrect\", style=\"rgb(255,0,0) bold\"),\n", " Column(\"Logit Difference\", style=\"bold\"),\n", "]\n", "table = Table(*cols, title=\"Logit differences\")\n", "\n", "for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):\n", " table.add_row(prompt, repr(answer[0]), repr(answer[1]), f\"{logit_diff.item():.3f}\")\n", "\n", "rprint(table)" ] }, { "cell_type": "markdown", "metadata": { "id": "cSu6RMvSHu9A" }, "source": [ "
Solution\n", "\n", "```python\n", "def logits_to_ave_logit_diff(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", " per_prompt: bool = False,\n", ") -> Float[Tensor, \"*batch\"]:\n", " \"\"\"\n", " Returns logit difference between the correct and incorrect answer.\n", "\n", " If per_prompt=True, return the array of differences rather than the average.\n", " \"\"\"\n", " # Only the final logits are relevant for the answer\n", " final_logits: Float[Tensor, \"batch d_vocab\"] = logits[:, -1, :]\n", " # Get the logits corresponding to the indirect object / subject tokens respectively\n", " answer_logits: Float[Tensor, \"batch 2\"] = final_logits.gather(dim=-1, index=answer_tokens)\n", " # Find logit difference\n", " correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)\n", " answer_logit_diff = correct_logits - incorrect_logits\n", " return answer_logit_diff if per_prompt else answer_logit_diff.mean()\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "6OT--NwoHu9A" }, "source": [ "## Brainstorm What's Actually Going On" ] }, { "cell_type": "markdown", "metadata": { "id": "sMvzi9QsHu9A" }, "source": [ "Before diving into running experiments, it's often useful to spend some time actually reasoning about how the behaviour in question could be implemented in the transformer. **This is optional, and you'll likely get the most out of engaging with this section if you have a decent understanding already of what a transformer is and how it works!**\n", "\n", "You don't have to do this and forming hypotheses after exploration is also reasonable, but I think it's often easier to explore and interpret results with some grounding in what you might find. In this particular case, I'm cheating somewhat, since I know the answer, but I'm trying to simulate the process of reasoning about it!\n", "\n", "Note that often your hypothesis will be wrong in some ways and often be completely off. We're doing science here, and the goal is to understand how the model *actually* works, and to form true beliefs! There are two separate traps here at two extremes that it's worth tracking:\n", "* Confusion: Having no hypotheses at all, getting a lot of data and not knowing what to do with it, and just floundering around\n", "* Dogmatism: Being overconfident in an incorrect hypothesis and being unwilling to let go of it when reality contradicts you, or flinching away from running the experiments that might disconfirm it.\n", "\n", "**Exercise:** Spend some time thinking through how you might imagine this behaviour being implemented in a transformer. Try to think through this for yourself before reading through my thoughts!\n", "\n", "
(*) My reasoning\n", "\n", "

Brainstorming:

\n", "\n", "So, what's hard about the task? Let's focus on the concrete example of the first prompt, `\"When John and Mary went to the shops, John gave the bag to\" -> \" Mary\"`.\n", "\n", "A good starting point is thinking though whether a tiny model could do this, e.g. a 1L Attn-Only model. I'm pretty sure the answer is no! Attention is really good at the primitive operations of looking nearby, or copying information. I can believe a tiny model could figure out that at `to` it should look for names and predict that those names came next (e.g. the skip trigram \" John...to -> John\"). But it's much harder to tell how many of each previous name there are - attending to each copy of John will look exactly the same as attending to a single John token. So this will be pretty hard to figure out on the \" to\" token!\n", "\n", "The natural place to break this symmetry is on the second `\" John\"` token - telling whether there is an earlier copy of the current token should be a much easier task. So I might expect there to be a head which detects duplicate tokens on the second `\" John\"` token, and then another head which moves that information from the second `\" John` token to the `\" to\"` token.\n", "\n", "The model then needs to learn to predict `\" Mary\"` and not `\" John\"`. I can see two natural ways to do this:\n", "1. Detect all preceding names and move this information to \" to\" and then delete the any name corresponding to the duplicate token feature. This feels easier done with a non-linearity, since precisely cancelling out vectors is hard, so I'd imagine an MLP layer deletes the `\" John\"` direction of the residual stream.\n", "2. Have a head which attends to all previous names, but where the duplicate token features inhibit it from attending to specific names. So this only attends to Mary. And then the output of this head maps to the logits.\n", "\n", "
\n", "Spoiler - which one of these two is correct\n", "\n", "It's the second one.\n", "
\n", "\n", "

Experiment Ideas

\n", "\n", "A test that could distinguish these two is to look at which components of the model add directly to the logits - if it's mostly attention heads which attend to `\" Mary\"` and to neither `\" John\"` it's probably hypothesis 2, if it's mostly MLPs it's probably hypothesis 1.\n", "\n", "And we should be able to identify duplicate token heads by finding ones which attend from `\" John\"` to `\" John\"`, and whose outputs are then moved to the `\" to\"` token by V-Composition with another head (Spoiler: It's more complicated than that!)\n", "\n", "Note that all of the above reasoning is very simplistic and could easily break in a real model! There'll be significant parts of the model that figure out whether to use this circuit at all (we don't want to inhibit duplicated names when, e.g. figuring out what goes at the start of the next sentence), and may be parts towards the end of the model that do \"post-processing\" just before the final output. But it's a good starting point for thinking about what's going on.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "RjXLQuA4Hu9A" }, "source": [ "# 2️⃣ Logit Attribution\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Perform direct logit attribution to figure out which heads are writing to the residual stream in a significant way\n", "> * Learn how to use different transformerlens helper functions, which decompose the residual stream in different ways" ] }, { "cell_type": "markdown", "metadata": { "id": "neSSR9s4Hu9A" }, "source": [ "## Direct Logit Attribution" ] }, { "cell_type": "markdown", "metadata": { "id": "UFuuixAcHu9A" }, "source": [ "The easiest part of the model to understand is the output - this is what the model is trained to optimize, and so it can always be directly interpreted! Often the right approach to reverse engineering a circuit is to start at the end, understand how the model produces the right answer, and to then work backwards (you will have seen this if you went through the balanced bracket classifier task, and in fact if you did then this section will probably be quite familiar to you and you should feel free to just skim through it). The main technique used to do this is called **direct logit attribution**\n", "\n", "**Background:** The central object of a transformer is the **residual stream**. This is the sum of the outputs of each layer and of the original token and positional embedding. Importantly, this means that any linear function of the residual stream can be perfectly decomposed into the contribution of each layer of the transformer. Further, each attention layer's output can be broken down into the sum of the output of each head (See [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) for details), and each MLP layer's output can be broken down into the sum of the output of each neuron (and a bias term for each layer).\n", "\n", "The logits of a model are `logits=Unembed(LayerNorm(final_residual_stream))`. The Unembed is a linear map, and LayerNorm is approximately a linear map, so we can decompose the logits into the sum of the contributions of each component, and look at which components contribute the most to the logit of the correct token! This is called **direct logit attribution**. Here we look at the direct attribution to the logit difference!" ] }, { "cell_type": "markdown", "metadata": { "id": "D3Q1bJVyHu9A" }, "source": [ "### Background and motivation of the logit difference\n", "\n", "Logit difference is actually a *really* nice and elegant metric and is a particularly nice aspect of the setup of Indirect Object Identification. In general, there are two natural ways to interpret the model's outputs: the output logits, or the output log probabilities (or probabilities).\n", "\n", "The logits are much nicer and easier to understand, as noted above. However, the model is trained to optimize the cross-entropy loss (the average of log probability of the correct token). This means it does not directly optimize the logits, and indeed if the model adds an arbitrary constant to every logit, the log probabilities are unchanged.\n", "\n", "But we have:\n", "\n", "```\n", "log_probs == logits.log_softmax(dim=-1) == logits - logsumexp(logits)\n", "```\n", "\n", "and because they differ by a constant, we have:\n", "\n", "```\n", "log_probs(\" Mary\") - log_probs(\" John\") = logits(\" Mary\") - logits(\" John\")\n", "```\n", "\n", "- the ability to add an arbitrary constant cancels out!\n", "\n", "
\n", "Technical details (if this equivalence doesn't seem obvious to you)\n", "\n", "Let $\\vec{\\textbf{x}}$ be the logits, $\\vec{\\textbf{L}}$ be the log probs, and $\\vec{\\textbf{p}}$ be the probs. Then we have the following relations:\n", "\n", "$$\n", "p_i = \\operatorname{softmax}(\\vec{\\textbf{x}})_i = \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}}\n", "$$\n", "\n", "and:\n", "\n", "$$\n", "L_i = \\log p_i\n", "$$\n", "\n", "Combining these, we get:\n", "\n", "$$\n", "L_i = \\log \\frac{e^{x_i}}{\\sum_{j=1}^n e^{x_j}} = x_i - \\log \\sum_{j=1}^n e^{x_j}\n", "$$\n", "\n", "Notice that the sum term on the right hand side is the same for all $i$, so we get:\n", "\n", "$$\n", "L_i - L_j = x_i - x_j\n", "$$\n", "\n", "In other words, the logit diff $x_i - x_j$ is the same as the log prob diff. This motivates the choice of logit diff as our metric (since the model is directly trained to make the log prob of the correct token large, and all other log probs small).\n", "\n", "
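\n", "If you want to convince yourself of this numerically, here's a quick standalone sanity check you could run (a sketch, not part of the original exercises):\n", "\n", "```python\n", "import torch as t\n", "\n", "logits = t.randn(10)  # toy logits over a 10-token vocab\n", "shifted = logits + 5.0  # add an arbitrary constant to every logit\n", "\n", "# the log probs are unchanged by the constant shift...\n", "t.testing.assert_close(logits.log_softmax(dim=-1), shifted.log_softmax(dim=-1))\n", "\n", "# ...and the log prob difference between two tokens equals the logit difference\n", "log_probs = logits.log_softmax(dim=-1)\n", "t.testing.assert_close(log_probs[3] - log_probs[7], logits[3] - logits[7])\n", "```\n", "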
\n", "\n", "Further, the metric helps us isolate the precise capability we care about - figuring out *which* name is the Indirect Object. There are many other components of the task - deciding whether to return an article (the) or pronoun (her) or name, realising that the sentence wants a person next at all, etc. By taking the logit difference we control for all of that.\n", "\n", "Our metric is further refined, because each prompt is repeated twice, for each possible indirect object. This controls for irrelevant behaviour such as the model learning that John is a more frequent token than Mary (this actually happens! The final layernorm bias increases the John logit by 1 relative to the Mary logit). Another way to handle this would be to use a large enough dataset (with names randomly chosen) that this effect is averaged out, which is what we'll do in section 3." ] }, { "cell_type": "markdown", "metadata": { "id": "5RJsbbS7Hu9B" }, "source": [ "
Ignoring LayerNorm\n", "\n", "LayerNorm is a normalization technique that transformers use, analogous to BatchNorm but friendlier to massive parallelization. Every time a transformer layer reads information from the residual stream, it applies a LayerNorm to normalize the vector at each position (translating to set the mean to 0 and scaling to set the variance to 1) and then applies a learned vector of weights and biases to scale and translate the normalized vector. This is *almost* a linear map, apart from the scaling step, because that divides by the norm of the vector and the norm is not a linear function. (The `fold_ln` flag when loading a model factors out all the linear parts).\n", "\n", "But if we fixed the scale factor, the LayerNorm would be fully linear. And the scale of the residual stream is a global property that's a function of *all* components of the stream, while there are normally just a few directions relevant to any particular component, so in practice this is an acceptable approximation. So when doing direct logit attribution we use the `apply_ln` flag on the `cache` to apply the global layernorm scaling factor to each component. See [my clean GPT-2 implementation](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/clean-transformer-demo/Clean_Transformer_Demo.ipynb#scrollTo=Clean_Transformer_Implementation) for more on LayerNorm.\n", "
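\n", "To make the \"fixed scale factor means linear\" point concrete, here's a small standalone sketch (not part of the original exercises). LayerNorm doesn't decompose across a sum of inputs, but once we freeze the scale computed on the full sum, the centre-and-scale map does:\n", "\n", "```python\n", "import torch as t\n", "\n", "def ln_pre(x):  # LayerNorm without the learned weights & biases\n", "    x = x - x.mean(-1, keepdim=True)\n", "    return x / x.std(-1, keepdim=True)\n", "\n", "x1, x2 = t.randn(768), t.randn(768)\n", "\n", "# LayerNorm itself is not linear: normalising a sum isn't the sum of normalisations\n", "print(t.allclose(ln_pre(x1 + x2), ln_pre(x1) + ln_pre(x2)))  # False\n", "\n", "# But with the scale frozen (computed once, on the summed input), the map is linear,\n", "# so each component's contribution can be centred & scaled separately and then summed\n", "scale = (x1 + x2 - (x1 + x2).mean()).std()\n", "lhs = ((x1 + x2) - (x1 + x2).mean()) / scale\n", "rhs = (x1 - x1.mean()) / scale + (x2 - x2.mean()) / scale\n", "print(t.allclose(lhs, rhs))  # True\n", "```\n", "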
" ] }, { "cell_type": "markdown", "metadata": { "id": "Lz0ixUUaHu9B" }, "source": [ "### Logit diff directions" ] }, { "cell_type": "markdown", "metadata": { "id": "V6AMRqMaHu9B" }, "source": [ "***Getting an output logit is equivalent to projecting onto a direction in the residual stream, and the same is true for getting the logit diff.***\n", "\n", "
\n", "If it's not clear what is meant by this statement, read this dropdown.\n", "\n", "Suppose our final value in the residual stream for a single sequence and a position within that sequence is $x$ (i.e. $x$ is a vector of length $d_{model}$). Then (ignoring layernorm - see the point above for why it's okay to do this), we get logits by multiplying by the unembedding matrix $W_U$ (which has shape $(d_{model}, d_{vocab})$):\n", "\n", "$$\n", "\\text{output} = x^T W_U\n", "$$\n", "\n", "Now, remember that we want the logit diff, which is $\\text{output}_{IO} - \\text{output}_{S}$ (the difference between the logits for our indirect object and subject). We can write this as:\n", "\n", "$$\n", "\\text{logit diff} = (x^T W_U)_{IO} - (x^T W_U)_{S} = x^T (u_{IO} - u_{S})\n", "$$\n", "\n", "where $u_{IO}$ and $u_S$ are the **columns of the unembedding matrix** $W_U$ corresponding to the indirect object and subject tokens respectively.\n", "\n", "To summarize, we've written the logit diff as a dot product between the vector in the residual stream and a constant vector (which is a function of the model's unembedding matrix). We call this vector $u_{IO} - u_{S}$ the **logit difference direction** (because it *\"points in the direction of largest logit difference\"*). To put it another way, if $x$ is a vector of fixed magnitude, then it maximises the logit difference when it is pointing in the same direction as the vector $u_{IO} - u_{S}$. We use the term \"projection\" synonymously with \"dot product\" here.\n", "\n", "(If you've completed the exercise where we interpret a transformer on balanced / unbalanced bracket strings, this is basically the same principle. The only difference here is that we actually have a much larger unembedding vocabulary than just the classifications `{balanced, unbalanced}`, but since we're only interested in comparing the model's prediction for IO vs S, and the logits for these two tokens are usually larger than most others, this method is still well-justified).\n", "
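\n", "Here's a tiny standalone numerical check of the claim above (illustrative only - the dimensions and token indices are made up):\n", "\n", "```python\n", "import torch as t\n", "\n", "d_model, d_vocab = 16, 50\n", "x = t.randn(d_model)  # final residual stream vector (pretend layernorm is already applied)\n", "W_U = t.randn(d_model, d_vocab)  # unembedding matrix\n", "io, s = 3, 7  # hypothetical token ids for the IO and S names\n", "\n", "logits = x @ W_U\n", "logit_diff_direction = W_U[:, io] - W_U[:, s]\n", "t.testing.assert_close(logits[io] - logits[s], x @ logit_diff_direction)\n", "```\n", "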
\n", "\n", "We use `model.tokens_to_residual_directions` to map the answer tokens to that direction, and then convert this to a logit difference direction for each batch" ] }, { "cell_type": "code", "execution_count": 32, "metadata": { "id": "kRtTa_QIHu9B", "outputId": "8d305924-fa5b-47d1-95e4-704e824abd1d", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Answer residual directions shape: torch.Size([8, 2, 768])\n", "Logit difference directions shape: torch.Size([8, 768])\n" ] } ], "source": [ "answer_residual_directions = model.tokens_to_residual_directions(answer_tokens) # [batch 2 d_model]\n", "print(\"Answer residual directions shape:\", answer_residual_directions.shape)\n", "\n", "correct_residual_directions, incorrect_residual_directions = answer_residual_directions.unbind(dim=1)\n", "logit_diff_directions = correct_residual_directions - incorrect_residual_directions # [batch d_model]\n", "print(\"Logit difference directions shape:\", logit_diff_directions.shape)" ] }, { "cell_type": "markdown", "metadata": { "id": "ht6zHAndHu9B" }, "source": [ "To verify that this works, we can apply this to the final residual stream for our cached prompts (after applying LayerNorm scaling) and verify that we get the same answer.\n", "\n", "
Technical details\n", "\n", "`logits = Unembed(LayerNorm(final_residual_stream))`, so we technically need to account for the centering, and then learned translation and scaling of the layernorm, not just the variance 1 scaling.\n", "\n", "The centering is accounted for with the preprocessing flag `center_writing_weights` which ensures that every weight matrix writing to the residual stream has mean zero.\n", "\n", "The learned scaling is folded into the unembedding weights `model.unembed.W_U` via `W_U_fold = layer_norm.weights[:, None] * unembed.W_U`.\n", "\n", "The learned translation is folded into `model.unembed.b_U`, a bias added to the logits (note that GPT-2 is not trained with an existing `b_U`). This roughly represents unigram statistics. But we can ignore this because each prompt occurs twice with names in the opposite order, so this perfectly cancels out.\n", "\n", "Note that rather than using layernorm scaling we could just study `cache[\"ln_final.hook_normalized\"]`.\n", "\n", "
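\n", "As a toy illustration of this folding (a standalone sketch with made-up shapes, not actual TransformerLens code):\n", "\n", "```python\n", "import torch as t\n", "\n", "d_model, d_vocab = 8, 10\n", "W_U = t.randn(d_model, d_vocab)  # unfolded unembedding\n", "ln_w = t.randn(d_model)  # ln_final's learned scale\n", "ln_b = t.randn(d_model)  # ln_final's learned translation\n", "x = t.randn(d_model)  # pretend this is the centered & scaled residual stream\n", "\n", "# fold the learned scale into W_U, and the learned translation into a logit bias\n", "W_U_fold = ln_w[:, None] * W_U\n", "b_U_fold = ln_b @ W_U\n", "\n", "logits_unfolded = (x * ln_w + ln_b) @ W_U\n", "logits_folded = x @ W_U_fold + b_U_fold\n", "t.testing.assert_close(logits_unfolded, logits_folded)\n", "```\n", "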
" ] }, { "cell_type": "markdown", "metadata": { "id": "0NJ3VYcXHu9B" }, "source": [ "The code below does the following:\n", "\n", "* Gets the final residual stream values from the `cache` object (which you should already have defined above).\n", "* Apply layernorm scaling to these values.\n", " * This is done by `cache.apply_to_ln_stack`, a helpful function which takes a stack of residual stream values (e.g. a batch, or the residual stream decomposed into components), treats them as the input to a specific layer, and applies the layer norm scaling of that layer to them.\n", " * The keyword arguments here indicate that our input is the residual stream values for the last sequence position, and we want to apply the final layernorm in the model.\n", "* Project them along the unembedding directions (you've already defined these above, as `logit_diff_directions`)." ] }, { "cell_type": "code", "execution_count": 33, "metadata": { "id": "zZMhWjBVHu9I", "outputId": "609dc4ba-161f-416c-bf5f-8d8d3a5508a8", "colab": { "base_uri": "https://localhost:8080/" } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Final residual stream shape: torch.Size([8, 15, 768])\n", "Calculated average logit diff: 3.5518774986\n", "Original logit difference: 3.5518784523\n" ] } ], "source": [ "# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].\n", "final_residual_stream: Float[Tensor, \"batch seq d_model\"] = cache[\"resid_post\", -1]\n", "print(f\"Final residual stream shape: {final_residual_stream.shape}\")\n", "final_token_residual_stream: Float[Tensor, \"batch d_model\"] = final_residual_stream[:, -1, :]\n", "\n", "# Apply LayerNorm scaling (to just the final sequence position)\n", "# pos_slice is the subset of the positions we take - here the final token of each prompt\n", "scaled_final_token_residual_stream = cache.apply_ln_to_stack(final_token_residual_stream, layer=-1, pos_slice=-1)\n", "\n", "average_logit_diff = einops.einsum(\n", " scaled_final_token_residual_stream, logit_diff_directions, \"batch d_model, batch d_model ->\"\n", ") / len(prompts)\n", "\n", "print(f\"Calculated average logit diff: {average_logit_diff:.10f}\")\n", "print(f\"Original logit difference: {original_average_logit_diff:.10f}\")\n", "\n", "t.testing.assert_close(average_logit_diff, original_average_logit_diff)" ] }, { "cell_type": "markdown", "metadata": { "id": "b83i5IWRHu9I" }, "source": [ "## Logit Lens" ] }, { "cell_type": "markdown", "metadata": { "id": "S3m8NGvwHu9I" }, "source": [ "We can now decompose the residual stream! First we apply a technique called the [**logit lens**](https://www.alignmentforum.org/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens) - this looks at the residual stream after each layer and calculates the logit difference from that. This simulates what happens if we delete all subsequence layers." ] }, { "cell_type": "markdown", "metadata": { "id": "CbCh7LM3Hu9I" }, "source": [ "### Exercise - implement `residual_stack_to_logit_diff`\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴⚪⚪\n", "> Importance: 🔵🔵🔵⚪⚪\n", ">\n", "> You should spend up to 10-15 minutes on this exercise.\n", "> Again, make sure you understand what the output of this function represents.\n", "> ```\n", "\n", "This function should look a lot like your code immediately above. 
`residual_stack` is a tensor of shape `(..., batch, d_model)` containing the residual stream values for the final sequence position. You should apply the final layernorm to these values, then project them in the logit difference directions." ] }, { "cell_type": "code", "execution_count": 37, "metadata": { "id": "v-BXVJhIHu9I" }, "outputs": [], "source": [ "def residual_stack_to_logit_diff(\n", " residual_stack: Float[Tensor, \"... batch d_model\"],\n", " cache: ActivationCache,\n", " logit_diff_directions: Float[Tensor, \"batch d_model\"] = logit_diff_directions,\n", ") -> Float[Tensor, \"...\"]:\n", " \"\"\"\n", " Gets the avg logit difference between the correct and incorrect answer for a given stack of components in the\n", " residual stream.\n", " \"\"\"\n", " batch_size = residual_stack.size(-2)\n", " scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)\n", " return (\n", " einops.einsum(scaled_residual_stack, logit_diff_directions, \"... batch d_model, batch d_model -> ...\")\n", " / batch_size\n", " )\n", "\n", "\n", "\n", "# Test function by checking that it gives the same result as the original logit difference\n", "t.testing.assert_close(residual_stack_to_logit_diff(final_token_residual_stream, cache), original_average_logit_diff)" ] }, { "cell_type": "markdown", "metadata": { "id": "g-T9bXAtHu9I" }, "source": [ "
Solution\n", "\n", "```python\n", "def residual_stack_to_logit_diff(\n", " residual_stack: Float[Tensor, \"... batch d_model\"],\n", " cache: ActivationCache,\n", " logit_diff_directions: Float[Tensor, \"batch d_model\"] = logit_diff_directions,\n", ") -> Float[Tensor, \"...\"]:\n", " \"\"\"\n", " Gets the avg logit difference between the correct and incorrect answer for a given stack of components in the\n", " residual stream.\n", " \"\"\"\n", " batch_size = residual_stack.size(-2)\n", " scaled_residual_stack = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)\n", " return (\n", " einops.einsum(scaled_residual_stack, logit_diff_directions, \"... batch d_model, batch d_model -> ...\")\n", " / batch_size\n", " )\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "DtXonZBBHu9I" }, "source": [ "Once you have the solution, you can plot your results.\n", "\n", "
Details on accumulated_resid\n", "\n", "Key for the plot below: `n_pre` means the residual stream at the start of layer n, `n_mid` means the residual stream after the attention part of layer n (`n_post` is the same as `n+1_pre` so is not included)\n", "\n", "* `layer` is the layer for which we input the residual stream (this is used to identify *which* layer norm scaling factor we want)\n", "* `incl_mid` is whether to include the residual stream in the middle of a layer, ie after attention & before MLP\n", "* `pos_slice` is the subset of the positions used. See `utils.Slice` for details on the syntax.\n", "* `return_labels` is whether to return the labels for each component returned (useful for plotting)\n", "
" ] }, { "cell_type": "code", "execution_count": 38, "metadata": { "id": "jxrPQgJ6Hu9I", "outputId": "f78c3bff-fa82-42c4-f762-568cf5053793", "colab": { "base_uri": "https://localhost:8080/", "height": 542 } }, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {} } ], "source": [ "accumulated_residual, labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)\n", "# accumulated_residual has shape (component, batch, d_model)\n", "\n", "logit_lens_logit_diffs: Float[Tensor, \"component\"] = residual_stack_to_logit_diff(accumulated_residual, cache)\n", "\n", "line(\n", " logit_lens_logit_diffs,\n", " hovermode=\"x unified\",\n", " title=\"Logit Difference From Accumulated Residual Stream\",\n", " labels={\"x\": \"Layer\", \"y\": \"Logit Diff\"},\n", " xaxis_tickvals=labels,\n", " width=800,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "xqXPvhMTHu9I" }, "source": [ "
\n", "Question - what is the interpretation of this plot? What does this tell you about how the model solves this task?\n", "\n", "Fascinatingly, we see that the model is utterly unable to do the task until layer 7, almost all performance comes from attention layer 9, and performance actually *decreases* from there.\n", "\n", "This tells us that there must be something going on (primarily in layers 7, 8 and 9) which writes to the residual stream in the correct way to solve the IOI task. This allows us to narrow in our focus, and start asking questions about what kind of computation is going on in those layers (e.g. the contribution of attention layers vs MLPs, and which attention heads are most important).\n", "
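\n", "(If you want to check these claims against exact numbers rather than the interactive plot, you can print the values behind it, re-using the `labels` and `logit_lens_logit_diffs` computed above:)\n", "\n", "```python\n", "# print the accumulated logit diff at each point in the residual stream\n", "for label, diff in zip(labels, logit_lens_logit_diffs.tolist()):\n", "    print(f\"{label:>12}: {diff:+.3f}\")\n", "```\n", "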
" ] }, { "cell_type": "markdown", "metadata": { "id": "sUkI2kBzHu9I" }, "source": [ "## Layer Attribution" ] }, { "cell_type": "markdown", "metadata": { "id": "2ae_dSkDHu9I" }, "source": [ "We can repeat the above analysis but for each layer (this is equivalent to the differences between adjacent residual streams)\n", "\n", "Note: Annoying terminology overload - layer k of a transformer means the kth **transformer block**, but each block consists of an **attention layer** (to move information around) *and* an **MLP layer** (to process information)." ] }, { "cell_type": "code", "execution_count": 39, "metadata": { "id": "cZMSjkh5Hu9I", "outputId": "7e9d0d54-b020-49b5-a0da-5c30372b8782", "colab": { "base_uri": "https://localhost:8080/", "height": 542 } }, "outputs": [ { "output_type": "display_data", "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {} } ], "source": [ "per_layer_residual, labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)\n", "per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, cache)\n", "\n", "line(\n", " per_layer_logit_diffs,\n", " hovermode=\"x unified\",\n", " title=\"Logit Difference From Each Layer\",\n", " labels={\"x\": \"Layer\", \"y\": \"Logit Diff\"},\n", " xaxis_tickvals=labels,\n", " width=800,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "VPlnY_D7Hu9I" }, "source": [ "
\n", "Question - what is the interpretation of this plot? What does this tell you about how the model solves this task?\n", "\n", "We see that only attention layers matter, which makes sense! The IOI task is about moving information around (i.e. moving the correct name and not the incorrect name), and less about processing it. And again we note that attention layer 9 improves things a lot, while attention 10 and attention 11 *decrease* performance.\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Wsmhp_zaHu9I" }, "source": [ "## Head Attribution" ] }, { "cell_type": "markdown", "metadata": { "id": "k3MRRO08Hu9I" }, "source": [ "We can further break down the output of each attention layer into the sum of the outputs of each attention head. Each attention layer consists of 12 heads, which each act independently and additively.\n", "\n", "
Decomposing attention output into sums of heads\n", "\n", "The standard way to compute the output of an attention layer is by concatenating the mixed values of each head, and multiplying by a big output weight matrix. But as described in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) this is equivalent to splitting the output weight matrix into a per-head output (here `model.blocks[k].attn.W_O`) and adding them up (including an overall bias term for the entire layer).\n", "
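\n", "As an optional sanity check of this claim, the sketch below recombines the per-head outputs for a single layer and compares them to that layer's `attn_out` (it assumes the `model` and `cache` from earlier are still in scope, and allows a small floating point tolerance):\n", "\n", "```python\n", "layer = 9\n", "z = cache[\"z\", layer]  # mixed values for each head: [batch, pos, head, d_head]\n", "per_head_out = einops.einsum(\n", "    z, model.W_O[layer],\n", "    \"batch pos head d_head, head d_head d_model -> batch pos head d_model\",\n", ")\n", "# summing over heads (plus the layer's single output bias) should recover attn_out\n", "recombined = per_head_out.sum(dim=2) + model.b_O[layer]\n", "t.testing.assert_close(recombined, cache[\"attn_out\", layer], atol=1e-3, rtol=1e-3)\n", "```\n", "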
" ] }, { "cell_type": "code", "execution_count": 43, "metadata": { "id": "C6S3wNJyHu9I", "outputId": "fed0fff6-83c2-49d1-c384-24191e961e33", "colab": { "base_uri": "https://localhost:8080/", "height": 577 } }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "torch.Size([144, 8, 768])\n", "torch.Size([12, 12, 8, 768])\n" ] }, { "output_type": "display_data", "data": { "text/html": [ "\n", "\n", "\n", "
\n", "
\n", "\n", "" ] }, "metadata": {} } ], "source": [ "per_head_residual, labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)\n", "print(per_head_residual.shape)\n", "per_head_residual = einops.rearrange(per_head_residual, \"(layer head) ... -> layer head ...\", layer=model.cfg.n_layers)\n", "print(per_head_residual.shape)\n", "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, cache)\n", "\n", "fig = imshow(\n", " per_head_logit_diffs,\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " title=\"Logit Difference From Each Head\",\n", " width=600,\n", " return_fig=True,\n", ")\n", "\n", "fig.write_html(section_dir / \"14103.html\")\n", "fig.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "2OvuqtN9Hu9I" }, "source": [ "We see that only a few heads really matter - heads 9.6 and 9.9 contribute a lot positively (explaining why attention layer 9 is so important), while heads 10.7 and 11.10 contribute a lot negatively (explaining why attention layer 10 and layer 11 are actively harmful). These correspond to (some of) the name movers and negative name movers discussed in the paper. There are also several heads that matter positively or negatively but less strongly (other name movers and backup name movers).\n", "\n", "There are a few meta observations worth making here - our model has 144 heads, yet we could localise this behaviour to a handful of specific heads, using straightforward, general techniques. This supports the claim in [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) that attention heads are the right level of abstraction to understand attention. It also really surprising that there are *negative* heads - eg 10.7 makes the incorrect logit 7x *more* likely. I'm not sure what's going on there, though the paper discusses some possibilities." ] }, { "cell_type": "markdown", "metadata": { "id": "5TMCJz7pHu9I" }, "source": [ "## Recap of useful functions from this section\n", "\n", "Here, we take stock of all the functions from transformerlens which you might not have seen previously.\n", "\n", "* `cache.apply_ln_to_stack`\n", " * Apply layernorm scaling to a stack of residual stream values.\n", " * We used this to help us go from \"final value in residual stream\" to \"projection of logits in logit difference directions\", without getting the code too messy!\n", "* `cache.accumulated_resid(layer=None)`\n", " * Returns the accumulated residual stream up to layer `layer` (or up to the final value of residual stream if layer is None), i.e. a stack of previous residual streams up to that layer's input.\n", " * Useful when studying the **logit lens**.\n", " * First dimension of output is `(0_pre, 0_mid, 1_pre, 1_mid, ..., final_post)`\n", "* `cache.decompose_resid(layer)`.\n", " * Decomposes the residual stream input to layer `layer` into a stack of the output of previous layers. The sum of these is the input to layer `layer`.\n", " * First dimension of output is `(embed, pos_embed, 0_attn_out, 0_mlp_out, ...)`.\n", "* `cache.stack_head_results(layer)`\n", " * Returns a stack of all head results (i.e. residual stream contribution) up to layer `layer`\n", " * (i.e. like `decompose_resid` except it splits each attention layer by head rather than splitting each layer by attention/MLP)\n", " * First dimension of output is `layer * head` (we needed to rearrange to `(layer, head)` to plot it)." 
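, "\n", "\n", "If you want to double-check the shapes and label conventions described above, a quick sketch like this (re-using the `cache` from earlier) prints them out:\n", "\n", "```python\n", "acc_resid, acc_labels = cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)\n", "print(acc_resid.shape, acc_labels[:3], \"...\", acc_labels[-1])  # [component, batch, d_model]\n", "\n", "decomp, decomp_labels = cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)\n", "print(decomp.shape, decomp_labels[:3])  # [component, batch, d_model]\n", "\n", "head_results, head_labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)\n", "print(head_results.shape, head_labels[:3])  # [layer * head, batch, d_model]\n", "```"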
] }, { "cell_type": "markdown", "metadata": { "id": "ZeRlxZECHu9I" }, "source": [ "## Attention Analysis\n", "\n", "Attention heads are particularly fruitful to study because we can look directly at their attention patterns and see which positions they move information from and to. This is particularly useful here as we're looking at the direct effect on the logits, so we need only look at the attention patterns from the final token.\n", "\n", "We use the `circuitsvis` library (developed from Anthropic's PySvelte library) to visualize the attention patterns! We visualize the top 3 positive and negative heads by direct logit attribution, and show these for the first prompt (as an illustration).\n", "\n", "
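If you prefer raw numbers to interactive visualisations, a minimal sketch like the one below prints where a single head attends from the final token (head 9.9 is just an example choice, and `cache` / `tokens` are assumed from earlier):\n", "\n", "```python\n", "layer, head = 9, 9\n", "pattern = cache[\"pattern\", layer][0, head]  # attention pattern for the first prompt: [query_pos, key_pos]\n", "final_tok_attn = pattern[-1]  # attention paid by the final token to each earlier position\n", "for tok, prob in zip(model.to_str_tokens(tokens[0]), final_tok_attn.tolist()):\n", "    print(f\"{tok!r}: {prob:.3f}\")\n", "```\n", "\n", "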
Interpreting Attention Patterns\n", "\n", "A common mistake to make when looking at attention patterns is thinking that they must convey information about the *token* looked at (maybe accounting for the context of the token). But actually, all we can confidently say is that it moves information from the *residual stream position* corresponding to that input token. Especially later on in the model, there may be components in the residual stream that are nothing to do with the input token! Eg the period at the end of a sentence may contain summary information for that sentence, and the head may solely move that, rather than caring about whether it ends in \".\", \"!\" or \"?\"\n", "
" ] }, { "cell_type": "code", "execution_count": 41, "metadata": { "id": "qTQzQilsHu9I", "outputId": "36fe3f0f-8091-45e6-ab6a-5c3fd2de35ea", "colab": { "base_uri": "https://localhost:8080/", "height": 768 } }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "

Top 3 Positive Logit Attribution Heads

" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
\n", " " ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "

Top 3 Negative Logit Attribution Heads

" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "" ], "text/html": [ "
\n", " " ] }, "metadata": {} } ], "source": [ "def topk_of_Nd_tensor(tensor: Float[Tensor, \"rows cols\"], k: int):\n", " \"\"\"\n", " Helper function: does same as tensor.topk(k).indices, but works over 2D tensors.\n", " Returns a list of indices, i.e. shape [k, tensor.ndim].\n", "\n", " Example: if tensor is 2D array of values for each head in each layer, this will\n", " return a list of heads.\n", " \"\"\"\n", " i = t.topk(tensor.flatten(), k).indices\n", " return np.array(np.unravel_index(utils.to_numpy(i), tensor.shape)).T.tolist()\n", "\n", "\n", "k = 3\n", "\n", "for head_type in [\"Positive\", \"Negative\"]:\n", " # Get the heads with largest (or smallest) contribution to the logit difference\n", " top_heads = topk_of_Nd_tensor(per_head_logit_diffs * (1 if head_type == \"Positive\" else -1), k)\n", "\n", " # Get all their attention patterns\n", " attn_patterns_for_important_heads: Float[Tensor, \"head q k\"] = t.stack(\n", " [cache[\"pattern\", layer][:, head][0] for layer, head in top_heads]\n", " )\n", "\n", " # Display results\n", " display(HTML(f\"

Top {k} {head_type} Logit Attribution Heads

\"))\n", " display(\n", " cv.attention.attention_patterns(\n", " attention=attn_patterns_for_important_heads,\n", " tokens=model.to_str_tokens(tokens[0]),\n", " attention_head_names=[f\"{layer}.{head}\" for layer, head in top_heads],\n", " )\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "8M2IGtqBHu9I" }, "source": [ "Reminder - you can use `attention_patterns` or `attention_heads` for these visuals. The former lets you see the actual values, the latter lets you hover over tokens in a printed sentence (and it provides other useful features like locking on tokens, or a superposition of all heads in the display). Both can be useful in different contexts (although I'd recommend usually using `attention_patterns`, it's more useful in most cases for quickly getting a sense of attention patterns).\n", "\n", "Try replacing `attention_patterns` above with `attention_heads`, and compare the output.\n", "\n", "
\n", "Help - my attention_heads plots are behaving weirdly.\n", "\n", "This seems to be a bug in `circuitsvis` - on VSCode, the attention head plots continually shrink in size.\n", "\n", "Until this is fixed, one way to get around it is to open the plots in your browser. You can do this inline with the `webbrowser` library:\n", "\n", "```python\n", "attn_heads = cv.attention.attention_heads(\n", " attention = attn_patterns_for_important_heads,\n", " tokens = model.to_str_tokens(tokens[0]),\n", " attention_head_names = [f\"{layer}.{head}\" for layer, head in top_heads],\n", ")\n", "\n", "path = \"attn_heads.html\"\n", "\n", "with open(path, \"w\") as f:\n", " f.write(str(attn_heads))\n", "\n", "webbrowser.open(path)\n", "```\n", "\n", "To check exactly where this is getting saved, you can print your current working directory with `os.getcwd()`.\n", "
\n", "\n", "From these plots, you might want to start thinking about the algorithm which is being implemented. In particular, for the attention heads with high positive attribution scores, which tokens is `\" to\"` attending to? How might this head be affecting the logit diff score?\n", "\n", "We'll save a full hypothesis for how the model works until the end of the next section." ] }, { "cell_type": "markdown", "metadata": { "id": "JrdvPWDYHu9I" }, "source": [ "# 3️⃣ Activation Patching\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the idea of activation patching, and how it can be used\n", "> * Implement some of the activation patching helper functions in transformerlens from scratch (i.e. using hooks)\n", "> * Use activation patching to track the layers & sequence positions in the residual stream where important information is stored and processed\n", "> * By the end of this section, you should be able to draw a rough sketch of the IOI circuit" ] }, { "cell_type": "markdown", "metadata": { "id": "qrzh2XkaHu9I" }, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": { "id": "M-qhLcZmHu9J" }, "source": [ "The obvious limitation to the techniques used above is that they only look at the very end of the circuit - the parts that directly affect the logits. Clearly this is not sufficient to understand the circuit! We want to understand how things compose together to produce this final output, and ideally to produce an end-to-end circuit fully explaining this behaviour.\n", "\n", "The technique we'll use to investigate this is called **activation patching**. This was first introduced in [David Bau and Kevin Meng's excellent ROME paper](https://rome.baulab.info/), where it was called causal tracing.\n", "\n", "The setup of activation patching is to take two runs of the model on two different inputs, the clean run and the corrupted run. The clean run outputs the correct answer and the corrupted run does not. The key idea is that we give the model the corrupted input, but then **intervene** on a specific activation and **patch** in the corresponding activation from the clean run (i.e. replace the corrupted activation with the clean activation), and then continue the run. We then measure how much the output has updated towards the correct answer.\n", "\n", "We can then iterate over many possible activations and look at how much they affect the corrupted run. If patching in an activation significantly increases the probability of the correct answer, this allows us to *localise* which activations matter.\n", "\n", "In other words, this is a **denoising** algorithm - we run on a corrupted input and patch in activations from the clean run (we'll contrast this with **noising** below).\n", "\n", "The ability to localise is a key move in mechanistic interpretability - if the computation is diffuse and spread across the entire model, it is likely much harder to form a clean mechanistic story for what's going on. But if we can identify precisely which parts of the model matter, we can then zoom in and determine what they represent and how they connect up with each other, and ultimately reverse engineer the underlying circuit that they represent."
] }, { "cell_type": "markdown", "metadata": { "id": "z7kuubZLHu9J" }, "source": [ "The diagrams below demonstrate activation patching on an abstract neural network (the nodes represent activations, and the arrows between them are weight connections).\n", "\n", "A regular forward pass on the clean input looks like:\n", "\n", "\n", "\n", "And activation patching from a corrupted input (green) into a forward pass for the clean input (black) looks like:\n", "\n", "\n", "\n", "where the dotted line represents patching in a value (i.e. during the forward pass on the clean input, we replace node $D$ with the value it takes on the corrupted input). Nodes $H$, $G$ and $F$ are colored orange, to represent that they now follow a distribution which is not the same as clean or corrupted." ] }, { "cell_type": "markdown", "metadata": { "id": "qPh6uIzwHu9J" }, "source": [ "We can patch into a transformer in many different ways (e.g. values of the residual stream, the MLP, or attention heads' output - see below). We can also get even more granular by patching at particular sequence positions (not shown in diagram).\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "QR8QjZk5Hu9J" }, "source": [ "### Noising vs denoising\n", "\n", "We might call this algorithm a type of **noising**, since we're running the model on a clean input and adding noise by patching in from the corrupted input. We can also consider the opposite algorithm, **denoising**, where we run the model on a corrupted input and remove noise by patching in from the clean input.\n", "\n", "When would you use noising vs denoising? It depends on your goals. The results of denoising are much stronger, because showing that a component or set of components is sufficient for a task is a big deal. On the other hand, the complexity of transformers and interdependence of components means that noising a model can have unpredictable consequences. If loss goes up when we ablate a component, it doesn't necessarily mean that this component was necessary for the task. As an example, ablating MLP0 in gpt2-small seems to make performance much worse on basically any task (because it acts as a kind of extended embedding; more on this later in these exercises), but it's not doing anything important which is *specfic* for the IOI task." ] }, { "cell_type": "markdown", "metadata": { "id": "zbXx18P9Hu9J" }, "source": [ "### Example: denoising the residual stream\n", "\n", "The above was all fairly abstract, so let's zoom in and lay out a concrete example to understand Indirect Object Identification. We'll start with an exercise on denoising, but we'll move onto noising later in this section (and the next section, on path patching).\n", "\n", "Here our clean input will be the original sentences (e.g. \"When Mary and John went to the store, John gave a drink to\") and our corrupted input will have the subject token flipped (e.g. \"When Mary and John went to the store, Mary gave a drink to\"). Patching by replacing corrupted residual stream values with clean values is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. 
If a component is important, then patching in (replacing that component's corrupted output with its clean output) will reverse the signal that this component produces, hence making performance much better.\n", "\n", "Note - the noising and denoising terminology doesn't exactly fit here, since the \"noised dataset\" actually **reverses** the signal rather than erasing it. The reason we're describing this as denoising is more a matter of framing - we're trying to figure out which components / activations are **sufficient** to recover performance, rather than which are **necessary**. If you're ever confused, this is a useful framing to have - **noising tells you what is necessary, denoising tells you what is sufficient.**" ] }, { "cell_type": "markdown", "metadata": { "id": "rW51svu_Hu9J" }, "source": [ "Question - we could instead have our corrupted sentence be \"When John and Mary went to the store, Mary gave a drink to\" (i.e. flip all 3 occurrences of names in the sentence). Why do you think we don't do this?\n", "\n", "
\n", "Hint\n", "\n", "What if, at some point during the model's forward pass on the prompt `\"When Mary and John went to the store, John gave a drink to\"`, it contains some representation of the information **\"the indirect object is the fourth token in this sequence\"**?\n", "
\n", "\n", "
\n", "Answer\n", "\n", "The model could point to the indirect object `' Mary'` in two different ways:\n", "\n", "* Via **token information**, i.e. **\"the indirect object is the token `' Mary'`\"**.\n", "* Via **positional information**, i.e. **\"the indirect object is the fourth token in this sequence\"**.\n", "\n", "We want the corrupted dataset to reverse both these signals when it's patched into the clean dataset. But if we corrupted the dataset by flipping all three names, then:\n", "\n", "* The token information is flipped, because the corresponding information in the model for the corrupted prompt will be **\"the indirect object is the token `' Mary'`\"**.\n", "* The positional information is ***not*** flipped, because the corresponding information will still be **\"the indirect object is the fourth token in this sequence\"**.\n", "\n", "In fact, in the bonus section we'll take advantage of this fact to try and disentangle whether token or positional information is being used by the model (i.e. by flipping the token information but not the positional information, and vice-versa). Spoiler alert - it turns out to be using a bit of both!\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "8YLxGcs_Hu9J" }, "source": [ "One natural thing to patch in is the residual stream at a specific layer and specific position. For example, the model is likely intitially doing some processing on the `S2` token to realise that it's a duplicate, but then uses attention to move that information to the `end` token. So patching in the residual stream at the `end` token will likely matter a lot in later layers but not at all in early layers.\n", "\n", "We can zoom in much further and patch in specific activations from specific layers. For example, we think that the output of head 9.9 on the final token is significant for directly connecting to the logits, so we predict that just patching the output of this head will significantly affect performance.\n", "\n", "Note that this technique does *not* tell us how the components of the circuit connect up, just what they are." ] }, { "cell_type": "markdown", "metadata": { "id": "4h8zucuBHu9J" }, "source": [ "TransformerLens has helpful built-in functions to perform activation patching, but in order to understand the process better, you're now going to implement some of these functions from first principles (i.e. just using hooks). You'll be able to test your functions by comparing their output to the built-in functions.\n", "\n", "If you need a refresher on hooks, you can return to the exercises on induction heads (which take you through how to use hooks, as well as how to cache activations)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bzloSSkgHu9J" }, "outputs": [], "source": [ "from transformer_lens import patching" ] }, { "cell_type": "markdown", "metadata": { "id": "mRrxV9gVHu9J" }, "source": [ "## Creating a metric" ] }, { "cell_type": "markdown", "metadata": { "id": "E310LeZMHu9J" }, "source": [ "Before we patch, we need to create a metric for evaluating a set of logits. Since we'll be running our **corrupted prompts** (with `S2` replaced with the wrong name) and patching in our **clean prompts**, it makes sense to choose a metric such that:\n", "\n", "* A value of zero means no change (from the performance on the corrupted prompt)\n", "* A value of one means clean performance has been completely recovered\n", "\n", "For example, if we patched in the entire clean prompt, we'd get a value of one. If our patching actually makes the model even better at solving the task than its regular behaviour on the clean prompt then we'd get a value greater than 1, but generally we expect values between 0 and 1.\n", "\n", "It also makes sense to have the metric be a linear function of the logit difference. This is enough to uniquely specify a metric." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "nU-jT_3RHu9J" }, "outputs": [], "source": [ "clean_tokens = tokens\n", "# Swap each adjacent pair to get corrupted tokens\n", "indices = [i + 1 if i % 2 == 0 else i - 1 for i in range(len(tokens))]\n", "corrupted_tokens = clean_tokens[indices]\n", "\n", "print(\n", " \"Clean string 0: \",\n", " model.to_string(clean_tokens[0]),\n", " \"\\nCorrupted string 0:\",\n", " model.to_string(corrupted_tokens[0]),\n", ")\n", "\n", "clean_logits, clean_cache = model.run_with_cache(clean_tokens)\n", "corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)\n", "\n", "clean_logit_diff = logits_to_ave_logit_diff(clean_logits, answer_tokens)\n", "print(f\"Clean logit diff: {clean_logit_diff:.4f}\")\n", "\n", "corrupted_logit_diff = logits_to_ave_logit_diff(corrupted_logits, answer_tokens)\n", "print(f\"Corrupted logit diff: {corrupted_logit_diff:.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "woP2tFuGHu9J" }, "source": [ "### Exercise - create a metric\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴⚪⚪⚪\n", "> Importance: 🔵🔵🔵⚪⚪\n", ">\n", "> You should spend up to ~10 minutes on this exercise.\n", "> ```\n", "\n", "Fill in the function `ioi_metric` below, to create the required metric. Note that we can afford to use default arguments in this function, because we'll be using the same dataset for this whole section.\n", "\n", "**Important note** - this function needs to return a scalar tensor, rather than a float. If not, then some of the patching functions later on won't work. The type signature of this is `Float[Tensor, \"\"]`.\n", "\n", "**Second important note** - we've defined this to be 0 when performance is the same as on corrupted input, and 1 when it's the same as on clean input. This is because we're performing a **denoising algorithm**; we're looking for activations which are sufficient for recovering a model's performance (i.e. activations which have enough information to recover the correct answer from the corrupted input). Our \"null hypothesis\" is that the component isn't sufficient, and so patching it by replacing corrupted with clean values doesn't recover any performance. In later sections we'll be doing noising, and we'll define a new metric function for that." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "aKKrB_BvHu9J" }, "outputs": [], "source": [ "def ioi_metric(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", " corrupted_logit_diff: float = corrupted_logit_diff,\n", " clean_logit_diff: float = clean_logit_diff,\n", ") -> Float[Tensor, \"\"]:\n", " \"\"\"\n", " Linear function of logit diff, calibrated so that it equals 0 when performance is same as on corrupted input, and 1\n", " when performance is same as on clean input.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "t.testing.assert_close(ioi_metric(clean_logits).item(), 1.0)\n", "t.testing.assert_close(ioi_metric(corrupted_logits).item(), 0.0)\n", "t.testing.assert_close(ioi_metric((clean_logits + corrupted_logits) / 2).item(), 0.5)" ] }, { "cell_type": "markdown", "metadata": { "id": "DXlZ0QEaHu9J" }, "source": [ "
Solution\n", "\n", "```python\n", "def ioi_metric(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", " corrupted_logit_diff: float = corrupted_logit_diff,\n", " clean_logit_diff: float = clean_logit_diff,\n", ") -> Float[Tensor, \"\"]:\n", " \"\"\"\n", " Linear function of logit diff, calibrated so that it equals 0 when performance is same as on corrupted input, and 1\n", " when performance is same as on clean input.\n", " \"\"\"\n", " patched_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)\n", " return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)\n", "\n", "\n", "t.testing.assert_close(ioi_metric(clean_logits).item(), 1.0)\n", "t.testing.assert_close(ioi_metric(corrupted_logits).item(), 0.0)\n", "t.testing.assert_close(ioi_metric((clean_logits + corrupted_logits) / 2).item(), 0.5)\n", "```\n", "
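\n", "For intuition about this calibration, here's a purely illustrative arithmetic example (the corrupted logit diff isn't printed above, but by the symmetry of the paired prompts it should be roughly the negative of the clean value):\n", "\n", "```python\n", "clean, corrupted = 3.55, -3.55  # hypothetical logit diffs, roughly matching this dataset\n", "patched = 0.0  # a patched run that is exactly indifferent between the two names\n", "metric = (patched - corrupted) / (clean - corrupted)\n", "print(metric)  # 0.5 -> patching recovered half of the corrupted-to-clean gap\n", "```\n", "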
" ] }, { "cell_type": "markdown", "metadata": { "id": "MZBhVLphHu9J" }, "source": [ "## Residual Stream Patching" ] }, { "cell_type": "markdown", "metadata": { "id": "eOOmJoVFHu9J" }, "source": [ "Lets begin with a simple example: we patch in the residual stream at the start of each layer and for each token position. Before you write your own function to do this, let's see what this looks like with TransformerLens' `patching` module. Run the code below." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c12fnvrqHu9J" }, "outputs": [], "source": [ "act_patch_resid_pre = patching.get_act_patch_resid_pre(\n", " model=model, corrupted_tokens=corrupted_tokens, clean_cache=clean_cache, patching_metric=ioi_metric\n", ")\n", "\n", "labels = [f\"{tok} {i}\" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "0dTQ5XM9Hu9J" }, "outputs": [], "source": [ "imshow(\n", " act_patch_resid_pre,\n", " labels={\"x\": \"Position\", \"y\": \"Layer\"},\n", " x=labels,\n", " title=\"resid_pre Activation Patching\",\n", " width=600\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "qM1GviFNHu9J" }, "source": [ "Question - what is the interpretation of this graph? What significant things does it tell you about the nature of how the model solves this task?\n", "\n", "
\n", "Hint\n", "\n", "Think about locality of computation.\n", "
\n", "\n", "
\n", "Answer\n", "\n", "Originally all relevant computation happens on `S2`, and at layers 7 and 8, the information is moved to `END`. Moving the residual stream at the correct position near *exactly* recovers performance!\n", "\n", "To be clear, the striking thing about this graph isn't that the first row is zero everywhere except for `S2` where it is 1, or that the rows near the end trend to being zero everywhere except for `END` where they are 1; both of these are exactly what we'd expect. The striking things are:\n", "\n", "* The computation is highly localized; the relevant information for choosing `IO` over `S` is initially stored in `S2` token and then moved to `END` token without taking any detours.\n", "* The model is basically done after layer 8, and the rest of the layers actually slightly impede performance on this particular task.\n", "\n", "(Note - for reference, tokens and their index from the first prompt are on the x-axis. In an abuse of notation, note that the difference here is averaged over *all* 8 prompts, while the labels only come from the *first* prompt.)\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Mjb4T9C3Hu9J" }, "source": [ "### Exercise - implement head-to-residual patching\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴⚪\n", "> Importance: 🔵🔵🔵🔵🔵\n", ">\n", "> You should spend up to 20-25 minutes on this exercise.\n", ">\n", "> It's very important to understand how patching works. Many subsequent exercises will build on this one.\n", "> ```\n", "\n", "Now, you should implement the `get_act_patch_resid_pre` function below, which should give you results just like the code you ran above. A quick refresher on how to use hooks in this way:\n", "\n", "* Hook functions take arguments `tensor: t.Tensor` and `hook: HookPoint`. It's often easier to define a hook function taking more arguments than these, and then use `functools.partial` when it actually comes time to add your hook.\n", "* The function `model.run_with_hooks` takes arguments:\n", " * The tokens to run (as first argument)\n", " * `fwd_hooks` - a list of `(hook_name, hook_fn)` tuples. Remember that you can use `utils.get_act_name` to get hook names.\n", "* Tip - it's good practice to have `model.reset_hooks()` at the start of functions which add and run hooks. This is because sometimes hooks fail to be removed (if they cause an error while running). There's nothing more frustrating than fixing a hook error only to get the same error message, not realising that you've failed to clear the broken hook!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RGF2Lj7KHu9J" }, "outputs": [], "source": [ "def patch_residual_component(\n", " corrupted_residual_component: Float[Tensor, \"batch pos d_model\"],\n", " hook: HookPoint,\n", " pos: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos d_model\"]:\n", " \"\"\"\n", " Patches a given sequence position in the residual stream, using the value\n", " from the clean cache.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "def get_act_patch_resid_pre(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable[[Float[Tensor, \"batch pos d_vocab\"]], float],\n", ") -> Float[Tensor, \"layer pos\"]:\n", " \"\"\"\n", " Returns an array of results of patching each position at each layer in the residual\n", " stream, using the value from the clean cache.\n", "\n", " The results are calculated using the patching_metric function, which should be\n", " called on the model's logit output.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "act_patch_resid_pre_own = get_act_patch_resid_pre(model, corrupted_tokens, clean_cache, ioi_metric)\n", "\n", "t.testing.assert_close(act_patch_resid_pre, act_patch_resid_pre_own)" ] }, { "cell_type": "markdown", "metadata": { "id": "Up7b6ehVHu9J" }, "source": [ "
Solution\n", "\n", "```python\n", "def patch_residual_component(\n", " corrupted_residual_component: Float[Tensor, \"batch pos d_model\"],\n", " hook: HookPoint,\n", " pos: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos d_model\"]:\n", " \"\"\"\n", " Patches a given sequence position in the residual stream, using the value\n", " from the clean cache.\n", " \"\"\"\n", " corrupted_residual_component[:, pos, :] = clean_cache[hook.name][:, pos, :]\n", " return corrupted_residual_component\n", "\n", "\n", "def get_act_patch_resid_pre(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable[[Float[Tensor, \"batch pos d_vocab\"]], float],\n", ") -> Float[Tensor, \"layer pos\"]:\n", " \"\"\"\n", " Returns an array of results of patching each position at each layer in the residual\n", " stream, using the value from the clean cache.\n", "\n", " The results are calculated using the patching_metric function, which should be\n", " called on the model's logit output.\n", " \"\"\"\n", " model.reset_hooks()\n", " seq_len = corrupted_tokens.size(1)\n", " results = t.zeros(model.cfg.n_layers, seq_len, device=device, dtype=t.float32)\n", "\n", " for layer in tqdm(range(model.cfg.n_layers)):\n", " for position in range(seq_len):\n", " hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)\n", " patched_logits = model.run_with_hooks(\n", " corrupted_tokens,\n", " fwd_hooks=[(utils.get_act_name(\"resid_pre\", layer), hook_fn)],\n", " )\n", " results[layer, position] = patching_metric(patched_logits)\n", "\n", " return results\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "UVlboNMZHu9J" }, "source": [ "Once you've passed the tests, you can plot your results." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "LfIg33JVHu9J" }, "outputs": [], "source": [ "imshow(\n", " act_patch_resid_pre_own,\n", " x=labels,\n", " title=\"Logit Difference From Patched Residual Stream\",\n", " labels={\"x\": \"Sequence Position\", \"y\": \"Layer\"},\n", " width=700,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ExZpTUoKHu9K" }, "source": [ "## Patching in residual stream by block" ] }, { "cell_type": "markdown", "metadata": { "id": "wyZ5H4_cHu9K" }, "source": [ "Rather than just patching to the residual stream in each layer, we can also patch just after the attention layer or just after the MLP. This gives is a slightly more refined view of which tokens matter and when.\n", "\n", "The function `patching.get_act_patch_block_every` works just like `get_act_patch_resid_pre`, but rather than just patching to the residual stream, it patches to `resid_pre`, `attn_out` and `mlp_out`, and returns a tensor of shape `(3, n_layers, seq_len)`.\n", "\n", "One important thing to note - we're cycling through the `resid_pre`, `attn_out` and `mlp_out` and only patching one of them at a time, rather than patching all three at once." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CbEn-9Z7Hu9K" }, "outputs": [], "source": [ "act_patch_block_every = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sVZdsHPTHu9K" }, "outputs": [], "source": [ "imshow(\n", " act_patch_block_every,\n", " x=labels,\n", " facet_col=0, # This argument tells plotly which dimension to split into separate plots\n", " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"], # Subtitles of separate plots\n", " title=\"Logit Difference From Patched Attn Head Output\",\n", " labels={\"x\": \"Sequence Position\", \"y\": \"Layer\"},\n", " width=1200,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "xat-Lt6qHu9K" }, "source": [ "
\n", "Question - what is the interpretation of the second two plots?\n", "\n", "We see that several attention layers are significant but that, matching the residual stream results, early layers matter on `S2`, and later layers matter on `END`, and layers essentially don't matter on any other token. Extremely localised!\n", "\n", "As with direct logit attribution, layer 9 is positive and layers 10 and 11 are not, suggesting that the late layers only matter for direct logit effects, but we also see that layers 7 and 8 matter significantly. Presumably these are the heads that move information about which name is duplicated from `S2` to `END`.\n", "\n", "In contrast, the MLP layers do not matter much. This makes sense, since this is more a task about moving information than about processing it, and the MLP layers specialise in processing information. The one exception is MLP 0, which matters a lot, but I think this is misleading and just a generally true statement about MLP 0 rather than being about the circuit on this task.\n", "\n", "
My takes on MLP0\n", "\n", "It's often observed on GPT-2 Small that MLP0 matters a lot, and that ablating it utterly destroys performance. My current best guess is that the first MLP layer is essentially acting as an extension of the embedding (for whatever reason) and that when later layers want to access the input tokens they mostly read in the output of the first MLP layer, rather than the token embeddings. Within this frame, the first attention layer doesn't do much.\n", "\n", "In this framing, it makes sense that MLP0 matters on `S2`, because that's the one position with a different input token!\n", "\n", "I'm not entirely sure why this happens, but I would guess that it's because the embedding and unembedding matrices in GPT-2 Small are the same. This is pretty unprincipled, as the tasks of embedding and unembedding tokens are not inverses, but this is common practice, and plausibly models want to dedicate some parameters to overcoming this.\n", "\n", "I only have suggestive evidence of this, and would love to see someone look into this properly!\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "uDys6oVqHu9K" }, "source": [ "### Exercise (optional) - implement head-to-block patching\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴⚪⚪⚪\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> You should spend up to ~10 minutes on this exercise.\n", ">\n", "> Most code can be copied from the last exercise.\n", "> ```\n", "\n", "If you want, you can implement the `get_act_patch_resid_pre` function for fun, although it's similar enough to the previous exercise that doing this isn't compulsory." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-lfdCgb4Hu9K" }, "outputs": [], "source": [ "def get_act_patch_block_every(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable[[Float[Tensor, \"batch pos d_vocab\"]], float],\n", ") -> Float[Tensor, \"layer pos\"]:\n", " \"\"\"\n", " Returns an array of results of patching each position at each layer in the residual stream, using the value from the\n", " clean cache.\n", "\n", " The results are calculated using the patching_metric function, which should be called on the model's logit output.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "act_patch_block_every_own = get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)\n", "\n", "t.testing.assert_close(act_patch_block_every, act_patch_block_every_own)\n", "\n", "imshow(\n", " act_patch_block_every_own,\n", " x=labels,\n", " facet_col=0,\n", " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n", " title=\"Logit Difference From Patched Attn Head Output\",\n", " labels={\"x\": \"Sequence Position\", \"y\": \"Layer\"},\n", " width=1200,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "lB1w_8MZHu9K" }, "outputs": [], "source": [ "imshow(\n", " act_patch_block_every_own,\n", " x=labels,\n", " facet_col=0,\n", " facet_labels=[\"Residual Stream\", \"Attn Output\", \"MLP Output\"],\n", " title=\"Logit Difference From Patched Attn Head Output\",\n", " labels={\"x\": \"Sequence Position\", \"y\": \"Layer\"},\n", " width=1200\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "9JfHRdCqHu9K" }, "source": [ "
Solution\n", "\n", "```python\n", "def get_act_patch_block_every(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable[[Float[Tensor, \"batch pos d_vocab\"]], float],\n", ") -> Float[Tensor, \"layer pos\"]:\n", " \"\"\"\n", " Returns an array of results of patching each position at each layer in the residual stream, using the value from the\n", " clean cache.\n", "\n", " The results are calculated using the patching_metric function, which should be called on the model's logit output.\n", " \"\"\"\n", " model.reset_hooks()\n", " results = t.zeros(3, model.cfg.n_layers, tokens.size(1), device=device, dtype=t.float32)\n", "\n", " for component_idx, component in enumerate([\"resid_pre\", \"attn_out\", \"mlp_out\"]):\n", " for layer in tqdm(range(model.cfg.n_layers)):\n", " for position in range(corrupted_tokens.shape[1]):\n", " hook_fn = partial(patch_residual_component, pos=position, clean_cache=clean_cache)\n", " patched_logits = model.run_with_hooks(\n", " corrupted_tokens,\n", " fwd_hooks=[(utils.get_act_name(component, layer), hook_fn)],\n", " )\n", " results[component_idx, layer, position] = patching_metric(patched_logits)\n", "\n", " return results\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "g5LV_Y-9Hu9K" }, "source": [ "## Head Patching" ] }, { "cell_type": "markdown", "metadata": { "id": "vUcxA22-Hu9K" }, "source": [ "We can refine the above analysis by patching in individual heads! This is somewhat more annoying, because there are now three dimensions `(head_index, position and layer)`.\n", "\n", "The code below patches a head's output over all sequence positions, and returns the results (for each head in the model)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fHvuwRVrHu9K" }, "outputs": [], "source": [ "act_patch_attn_head_out_all_pos = patching.get_act_patch_attn_head_out_all_pos(\n", " model, corrupted_tokens, clean_cache, ioi_metric\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CXBnevp2Hu9K" }, "outputs": [], "source": [ "imshow(\n", " act_patch_attn_head_out_all_pos,\n", " labels={\"y\": \"Layer\", \"x\": \"Head\"},\n", " title=\"attn_head_out Activation Patching (All Pos)\",\n", " width=600\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "JyavGqZ-Hu9K" }, "source": [ "
\n", "Question - what are the interpretations of this graph? Which heads do you think are important?\n", "\n", "We see some of the heads that we observed in our attention plots at the end of last section (e.g. `9.9` having a large positive score, and `10.7` having a large negative score). But we can also see some other important heads, for instance:\n", "\n", "* In layers 7-8 there are several important heads. We might deduce that these are the ones responsible for moving information from `S2` to `end`.\n", "* In the earlier layers, there are some more important heads (e.g. `3.0` and `5.5`). We might guess these are performing some primitive logic, e.g. causing the second `\" John\"` token to attend to previous instances of itself.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "b8Yg74AXHu9K" }, "source": [ "### Exercise - implement head-to-head patching\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴⚪⚪\n", "> Importance: 🔵🔵🔵🔵⚪\n", ">\n", "> You should spend up to 10-15 minutes on this exercise.\n", ">\n", "> Again, it should be similar to the first patching exercise (you can copy code).\n", "> ```\n", "\n", "You should implement your own version of this patching function below.\n", "\n", "You'll need to define a new hook function, but most of the code from the previous exercise should be reusable.\n", "\n", "
\n", "Help - I'm not sure what hook name to use for my patching.\n", "\n", "You should patch at:\n", "\n", "```python\n", "utils.get_act_name(\"z\", layer)\n", "```\n", "\n", "This is the linear combination of value vectors, i.e. it's the thing you multiply by $W_O$ before adding back into the residual stream. There's no point patching after the $W_O$ multiplication, because it will have the same effect, but take up more memory (since `d_model` is larger than `d_head`).\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kYQG3gqRHu9K" }, "outputs": [], "source": [ "def patch_head_vector(\n", " corrupted_head_vector: Float[Tensor, \"batch pos head_index d_head\"],\n", " hook: HookPoint,\n", " head_index: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " Patches the output of a given head (before it's added to the residual stream) at every sequence position, using the\n", " value from the clean cache.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "def get_act_patch_attn_head_out_all_pos(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Returns an array of results of patching at all positions for each head in each layer, using the value from the clean\n", " cache. The results are calculated using the patching_metric function, which should be called on the model's logit\n", " output.\n", " \"\"\"\n", " raise NotImplementedError()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "P5jTcpDJHu9K" }, "outputs": [], "source": [ "act_patch_attn_head_out_all_pos_own = get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, clean_cache, ioi_metric)\n", "\n", "t.testing.assert_close(act_patch_attn_head_out_all_pos, act_patch_attn_head_out_all_pos_own)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZKM_ksVbHu9K" }, "outputs": [], "source": [ "imshow(\n", " act_patch_attn_head_out_all_pos_own,\n", " title=\"Logit Difference From Patched Attn Head Output\",\n", " labels={\"x\":\"Head\", \"y\":\"Layer\"},\n", " width=600\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "_XP-8A3qHu9K" }, "source": [ "
Solution\n", "\n", "```python\n", "def patch_head_vector(\n", " corrupted_head_vector: Float[Tensor, \"batch pos head_index d_head\"],\n", " hook: HookPoint,\n", " head_index: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " Patches the output of a given head (before it's added to the residual stream) at every sequence position, using the\n", " value from the clean cache.\n", " \"\"\"\n", " corrupted_head_vector[:, :, head_index] = clean_cache[hook.name][:, :, head_index]\n", " return corrupted_head_vector\n", "\n", "\n", "def get_act_patch_attn_head_out_all_pos(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Returns an array of results of patching at all positions for each head in each layer, using the value from the clean\n", " cache. The results are calculated using the patching_metric function, which should be called on the model's logit\n", " output.\n", " \"\"\"\n", " model.reset_hooks()\n", " results = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)\n", "\n", " for layer in tqdm(range(model.cfg.n_layers)):\n", " for head in range(model.cfg.n_heads):\n", " hook_fn = partial(patch_head_vector, head_index=head, clean_cache=clean_cache)\n", " patched_logits = model.run_with_hooks(\n", " corrupted_tokens, fwd_hooks=[(utils.get_act_name(\"z\", layer), hook_fn)], return_type=\"logits\"\n", " )\n", " results[layer, head] = patching_metric(patched_logits)\n", "\n", " return results\n", "\n", "\n", "act_patch_attn_head_out_all_pos_own = get_act_patch_attn_head_out_all_pos(\n", " model, corrupted_tokens, clean_cache, ioi_metric\n", ")\n", "\n", "t.testing.assert_close(act_patch_attn_head_out_all_pos, act_patch_attn_head_out_all_pos_own)\n", "\n", "imshow(\n", " act_patch_attn_head_out_all_pos_own,\n", " title=\"Logit Difference From Patched Attn Head Output\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " width=600,\n", ")\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "hXDbEzSjHu9K" }, "source": [ "## Decomposing Heads" ] }, { "cell_type": "markdown", "metadata": { "id": "a3tBVkbfHu9K" }, "source": [ "Finally, we'll look at one more example of activation patching.\n", "\n", "Decomposing attention layers into patching in individual heads has already helped us localise the behaviour a lot. But we can understand it further by decomposing heads. An attention head consists of two semi-independent operations - calculating *where* to move information from and to (represented by the attention pattern and implemented via the QK-circuit) and calculating *what* information to move (represented by the value vectors and implemented by the OV circuit). We can disentangle which of these is important by patching in just the attention pattern *or* the value vectors. See [A Mathematical Framework](https://transformer-circuits.pub/2021/framework/index.html) or [Neel's walkthrough video](https://www.youtube.com/watch?v=KV5gbOmHbjU) for more on this decomposition." ] }, { "cell_type": "markdown", "metadata": { "id": "hzzjWJLqHu9K" }, "source": [ "A useful function for doing this is `get_act_patch_attn_head_all_pos_every`. Rather than just patching on head output (like the previous one), it patches on:\n", "* Output (this is equivalent to patching the value the head writes to the residual stream)\n", "* Querys (i.e. the patching the query vectors, without changing the key or value vectors)\n", "* Keys\n", "* Values\n", "* Patterns (i.e. the attention patterns).\n", "\n", "Again, note that this function isn't patching multiple things at once. It's looping through each of these five, and getting the results from patching them one at a time." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jRT8eb2NHu9K" }, "outputs": [], "source": [ "act_patch_attn_head_all_pos_every = patching.get_act_patch_attn_head_all_pos_every(\n", " model, corrupted_tokens, clean_cache, ioi_metric\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "PlB2WmZ8Hu9K" }, "outputs": [], "source": [ "imshow(\n", " act_patch_attn_head_all_pos_every,\n", " facet_col=0,\n", " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", " title=\"Activation Patching Per Head (All Pos)\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "1fZg-78XHu9K" }, "source": [ "### Exercise (optional) - implement head-to-head-input patching\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴⚪⚪⚪\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> You should spend up to ~10 minutes on this exercise.\n", ">\n", "> Most code can be copied from the last exercise.\n", "> ```\n", "\n", "Again, if you want to implement this yourself then you can do so below, but it isn't a compulsory exercise because it isn't conceptually different from the previous exercises. If you don't implement it, then you should still look at the solution to make sure you understand what's going on." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "6CAnhbhXHu9K" }, "outputs": [], "source": [ "def patch_attn_patterns(\n", " corrupted_head_vector: Float[Tensor, \"batch head_index pos_q pos_k\"],\n", " hook: HookPoint,\n", " head_index: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " Patches the attn patterns of a given head at every sequence position, using the value from the clean cache.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "def get_act_patch_attn_head_all_pos_every(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Returns an array of results of patching at all positions for each head in each layer (using the value from the clean\n", " cache) for output, queries, keys, values and attn pattern in turn.\n", "\n", " The results are calculated using the patching_metric function, which should be called on the model's logit output.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "act_patch_attn_head_all_pos_every_own = get_act_patch_attn_head_all_pos_every(\n", " model, corrupted_tokens, clean_cache, ioi_metric\n", ")\n", "\n", "t.testing.assert_close(act_patch_attn_head_all_pos_every, act_patch_attn_head_all_pos_every_own)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NLZgIOEXHu9L" }, "outputs": [], "source": [ "imshow(\n", " act_patch_attn_head_all_pos_every_own,\n", " facet_col=0,\n", " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", " title=\"Activation Patching Per Head (All Pos)\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " width=1200\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "UISgLz8fHu9L" }, "source": [ "
Solution\n", "\n", "```python\n", "def patch_attn_patterns(\n", " corrupted_head_vector: Float[Tensor, \"batch head_index pos_q pos_k\"],\n", " hook: HookPoint,\n", " head_index: int,\n", " clean_cache: ActivationCache,\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " Patches the attn patterns of a given head at every sequence position, using the value from the clean cache.\n", " \"\"\"\n", " corrupted_head_vector[:, head_index] = clean_cache[hook.name][:, head_index]\n", " return corrupted_head_vector\n", "\n", "\n", "def get_act_patch_attn_head_all_pos_every(\n", " model: HookedTransformer,\n", " corrupted_tokens: Float[Tensor, \"batch pos\"],\n", " clean_cache: ActivationCache,\n", " patching_metric: Callable,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Returns an array of results of patching at all positions for each head in each layer (using the value from the clean\n", " cache) for output, queries, keys, values and attn pattern in turn.\n", "\n", " The results are calculated using the patching_metric function, which should be called on the model's logit output.\n", " \"\"\"\n", " results = t.zeros(5, model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)\n", " # Loop over each component in turn\n", " for component_idx, component in enumerate([\"z\", \"q\", \"k\", \"v\", \"pattern\"]):\n", " for layer in tqdm(range(model.cfg.n_layers)):\n", " for head in range(model.cfg.n_heads):\n", " # Get different hook function if we're doing attention probs\n", " hook_fn_general = patch_attn_patterns if component == \"pattern\" else patch_head_vector\n", " hook_fn = partial(hook_fn_general, head_index=head, clean_cache=clean_cache)\n", " # Get patched logits\n", " patched_logits = model.run_with_hooks(\n", " corrupted_tokens, fwd_hooks=[(utils.get_act_name(component, layer), hook_fn)], return_type=\"logits\"\n", " )\n", " results[component_idx, layer, head] = patching_metric(patched_logits)\n", "\n", " return results\n", "\n", "\n", "act_patch_attn_head_all_pos_every_own = get_act_patch_attn_head_all_pos_every(\n", " model, corrupted_tokens, clean_cache, ioi_metric\n", ")\n", "\n", "t.testing.assert_close(act_patch_attn_head_all_pos_every, act_patch_attn_head_all_pos_every_own)\n", "\n", "imshow(\n", " act_patch_attn_head_all_pos_every_own,\n", " facet_col=0,\n", " facet_labels=[\"Output\", \"Query\", \"Key\", \"Value\", \"Pattern\"],\n", " title=\"Activation Patching Per Head (All Pos)\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " width=1200,\n", ")\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "g9fCb_oRHu9L" }, "source": [ "Note - we can do this in an even more fine-grained way; the function `patching.get_act_patch_attn_head_by_pos_every` (i.e. same as above but replacing `all_pos` with `by_pos`) will give you the same decomposition, but by sequence position *as well as* by layer, head and component. The same holds for the `patching.get_act_patch_attn_head_out_all_pos` function earlier (replace `all_pos` with `by_pos`). These functions are unsurprisingly pretty slow though!" ] }, { "cell_type": "markdown", "metadata": { "id": "isETInp4Hu9L" }, "source": [ "This plot has some striking features. For instance, this shows us that we have at least three different groups of heads:\n", "\n", "* Earlier heads (`3.0`, `5.5`, `6.9`) which matter because of their attention patterns (specifically their query vectors).\n", "* Middle heads in layers 7 & 8 (`7.3`, `7.9`, `8.6`, `8.10`) seem to matter more because of their value vectors.\n", "* Later heads which improve the logit difference (`9.9`, `10.0`), which matter because of their query vectors.\n", "\n", "Question - what is the significance of the results for the middle heads (i.e. the important ones in layers 7 & 8)? In particular, how should we interpret the fact that value patching has a much bigger effect than the other two forms of patching?\n", "\n", "*Hint - if you're confused, try plotting the attention patterns of heads `7.3`, `7.9`, `8.6`, `8.10`. You can mostly reuse the code from above when we displayed the output of attention heads.*\n", "\n", "
\n", "Code to plot attention heads\n", "\n", "```python\n", "# Get the heads with largest value patching\n", "# (we know from plot above that these are the 4 heads in layers 7 & 8)\n", "k = 4\n", "top_heads = topk_of_Nd_tensor(act_patch_attn_head_all_pos_every[3], k=k)\n", "\n", "# Get all their attention patterns\n", "attn_patterns_for_important_heads: Float[Tensor, \"head q k\"] = t.stack([\n", " cache[\"pattern\", layer][:, head].mean(0)\n", " for layer, head in top_heads\n", "])\n", "\n", "# Display results\n", "display(HTML(f\"

Top {k} Logit Attribution Heads (from value-patching)
\"))\n", "display(cv.attention.attention_patterns(\n", " attention = attn_patterns_for_important_heads,\n", " tokens = model.to_str_tokens(tokens[0]),\n", " attention_head_names = [f\"{layer}.{head}\" for layer, head in top_heads],\n", "))\n", "```\n", "
\n", "\n", "
\n", "Answer\n", "\n", "The attention patterns show us that these heads attend from `END` to `S2`, so we can guess that they're responsible for moving information from `S2` to `END` which is used to determine the answer. This agrees with our earlier results, when we saw that most of the information gets moved over layers 7 & 8.\n", "\n", "The fact that value patching is the most important thing for them suggests that the interesting computation goes into **what information they move from `S2` to `end`**, rather than **why `end` attends to `S2`**. See the diagram below if you're confused why we can draw this inference.\n", "\n", "\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "0JG6kyWFHu9L" }, "source": [ "## Consolidating Understanding" ] }, { "cell_type": "markdown", "metadata": { "id": "6a4k8494Hu9L" }, "source": [ "OK, let's zoom out and reconsolidate. Here's a recap of the most important observations we have so far:\n", "\n", "* Heads `9.9`, `9.6`, and `10.0` are the most important heads in terms of directly writing to the residual stream. In all these heads, the `END` attends strongly to the `IO`.\n", " * We discovered this by taking the values written by each head in each layer to the residual stream, and projecting them along the logit diff direction by using `residual_stack_to_logit_diff`. We also looked at attention patterns using `circuitsvis`.\n", " * **This suggests that these heads are copying `IO` to `end`, to use it as the predicted next token.**\n", " * The question then becomes *\"how do these heads know to attend to this token, and not attend to `S`?\"*\n", "\n", "
\n", "\n", "* All the action is on `S2` until layer 7 and then transitions to `END`. And that attention layers matter a lot, MLP layers not so much (apart from MLP0, likely as an extended embedding).\n", " * We discovered this by doing **activation patching** on `resid_pre`, `attn_out`, and `mlp_out`.\n", " * **This suggests that there is a cluster of heads in layers 7 & 8, which move information from `S2` to `END`. We deduce that this information is how heads `9.9`, `9.6` and `10.0` know to attend to `IO`.**\n", " * The question then becomes *\"what is this information, how does it end up in the `S2` token, and how does `END` know to attend to it?\"*\n", "\n", "
\n", "\n", "* The significant heads in layers 7 & 8 are `7.3`, `7.9`, `8.6`, `8.10`. These heads have high activation patching values for their value vectors, less so for their queries and keys.\n", " * We discovered this by doing **activation patching** on the value inputs for these heads.\n", " * **This supports the previous observation, and it tells us that the interesting computation goes into *what gets moved* from `S2` to `END`, rather than the fact that `END` attends to `S2`.**.\n", " * We still don't know: *\"what is this information, and how does it end up in the `S2` token?\"*\n", "\n", "
\n", "\n", "* As well as the 2 clusters of heads given above, there's a third cluster of important heads: early heads (e.g. `3.0`, `5.5`, `6.9`) whose query vectors are particularly important for getting good performance.\n", " * We discovered this by doing **activation patching** on the query inputs for these heads." ] }, { "cell_type": "markdown", "metadata": { "id": "haqHP9V6Hu9L" }, "source": [ "With all this in mind, can you come up with a theory for what these three heads are doing, and come up with a simple model of the whole circuit?\n", "\n", "*Hint - if you're still stuck, try plotting the attention pattern of head `3.0`. The patterns of `5.5` and `6.9` might seem a bit confusing at first (they add complications to the \"simplest possible picture\" of how the circuit works); we'll discuss them later so they don't get in the way of understanding the core of the circuit.*\n", "\n", "
\n", "Answer (and simple diagram of circuit)\n", "\n", "If you plotted the attention pattern for head `3.0`, you should have seen that `S2` paid attention to `S1`. This suggests that the early heads are detecting when the destination token is a duplicate. So the information that the subject is a duplicate gets stored in `S2`.\n", "\n", "How can the information that the subject token is a duplicate help us predict the token after `end`? Well, the correct answer (the `IO` token) is the non-duplicated token. So we can infer that the information that the subject token is a duplicate is used to *inhibit* the attention of the late heads to the duplicated token, and they instead attend to the non-duplicated token.\n", "\n", "To summarise the second half of the circuit: information about this duplicated token is then moved from `S2` to `end` by the middle cluster of heads `7.3`, `7.9`, `8.6` and `8.10`, and this information goes into the queries of the late heads `9.9`, `9.6` and `10.0`, making them *inhibit* their attention to the duplicated token. Instead, they attend to `IO` (copying this token directly to the logits).\n", "\n", "This picture of the circuit turns out to be mostly right. It misses out on some subtleties which we'll discuss shortly, but it's a good rough picture to have in your head. We might illustrate this as follows:\n", "\n", "\n", "\n", "Explanation:\n", "\n", "* We call the early heads **DTH** (duplicate token heads), their job is to detect that `S2` is a duplicate.\n", "* The second group of heads are called **SIH** (S-inhibition heads), their job is to move the duplicated token information from `S2` to `END`. We've illustrated this as them moving the positional information, but in principle this could also be token embedding information (more on this in the final section).\n", "* The last group of heads are called **NMH** (name mover heads), their job is to copy the `IO` token to the `END` token, where it is used as the predicted next token (thanks to the S-inihbition heads, these heads don't pay attention to the `S` token).\n", "\n", "Note - if you're still confused about how to interpret this diagram, but you understand induction circuits and how they work, it might help to compare this diagram to one written in the same style which I made for [induction circuits](https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/ih-simple.png). Also, if you've read my induction heads [LessWrong post](https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated) and you're confused about how this style of diagram is different from that one, [here](https://raw.githubusercontent.com/info-arena/ARENA_img/main/misc/ih-compared.png) is an image comparing the two diagrams (for induction heads) and explaining how they differ.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "oQEfzprgHu9L" }, "source": [ "Now, let's flesh out this picture a bit more by comparing our results to the paper results. Below is a more complicated version of the diagram in the dropdown above, which also labels the important heads. The diagram is based on the paper's [original diagram](https://res.cloudinary.com/lesswrong-2-0/image/upload/v1672942728/mirroredImages/3ecs6duLmTfyra3Gp/h5icqzpyuhu4mqvfjhvw.png). Don't worry if you don't understand everything in this diagram; the boundaries of the circuit are fuzzy and the \"role\" of every head is in this circuit is a leaky abstraction. Rather, this diagram is meant to point your intuitions in the right direction for better understanding this circuit.\n", "\n", "
\n", "Diagram of large circuit\n", "\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "AfkNB0o7Hu9L" }, "source": [ "Here are the main ways it differs from the one above:" ] }, { "cell_type": "markdown", "metadata": { "id": "sIcTPPl1Hu9L" }, "source": [ "#### Induction heads\n", "\n", "Rather than just having duplicate token heads in the first cluster of heads, we have two other types of heads as well: previous token heads and induction heads. The induction heads do the same thing as the duplicate token heads, via an induction mechanism. They cause token `S2` to attend to `S1+1` (mediated by the previous token heads), and their output is used as both a pointer to `S1` and as a signal that `S1` is duplicated (more on the distinction between these two in the paragraph \"Position vs token information being moved\" below).\n", "\n", "*(Note - the original paper's diagram implies the induction heads and duplicate token heads compose with each other. This is misleading, and is not the case.)*\n", "\n", "Why are induction heads used in this circuit? We'll dig into this more in the bonus section, but one likely possibility is that induction heads are just a thing that forms very early on in training by default, and so it makes sense for the model to repurpose this already-existing machinery for this job. See [this paper](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html) for more on induction heads, and how / why they form." ] }, { "cell_type": "markdown", "metadata": { "id": "TP7AC6FqHu9L" }, "source": [ "#### Negative & Backup name mover heads\n", "\n", "Earlier, we saw that some heads in later layers were actually harming performance. These heads turn out to be doing something pretty similar to name mover heads, but in reverse (i.e. they inhibit the correct answer). It's not obvious why the model does this; the paper speculates that these heads might help the model \"hedge\" so as to avoid high cross-entropy loss when making mistakes.\n", "\n", "Backup name mover heads are possibly even weirder. It turns out that when we **ablate** the name mover heads, these ones pick up the slack and do the task anyway (even though they don't seem to do it when the NMHs aren't ablated). This is an example of **built-in redundancy** in the model. One possible explanation is that this resulted from the model being trained with dropout, although this explanation isn't fully satisfying (models trained without dropout still seem to have BNMHs, although they aren't as strong as they are in this model). Like with induction heads, we'll dig into this more in the final section." ] }, { "cell_type": "markdown", "metadata": { "id": "pRR-evqxHu9L" }, "source": [ "#### Positional vs token information\n", "\n", "There are 2 kinds of S-inhibition heads shown in the diagram - ones that inhibit based on positional information (pink), and ones that inhibit based on token information (purple). It's not clear which heads are doing which (and in fact some heads might be doing both!).\n", "\n", "The paper has an ingenious way of teasing apart which type of information is being used by which of the S-inhibition heads, which we'll discuss in the final section." ] }, { "cell_type": "markdown", "metadata": { "id": "f3B-oSZMHu9L" }, "source": [ "#### K-composition in S-inhibition heads\n", "\n", "When we did activation patching on the keys and values of S-inhibition heads, we found that the values were important and the keys weren't. 
We concluded that K-composition isn't really happening in these heads, and `END` must be paying attention to `S2` for reasons other than the duplicate token information (e.g. it might just be paying attention to the closest name, or to any names which aren't separated from it by a comma). Although this is mostly true, it turns out that there is a bit of K-composition happening in these heads. We can think of this as the duplicate token heads writing the \"duplicated\" flag to the residual stream (without containing any information about the identity and position of this token), and this flag is being used by the keys of the S-inhibition heads (i.e. they make `END` pay attention to `S2`). In the diagram, this is represented by the dark grey boxes (rather than just the light grey boxes we had in the simplified version). We haven't seen any evidence for this happening yet, but we will in the next section (when we look at path patching).\n", "\n", "Note - whether the early heads are writing positional information or \"duplicate flag\" information to the residual stream is not necessarily related to whether the head is an induction head or a duplicate token head. In principle, either type of head could write either type of information." ] }, { "cell_type": "markdown", "metadata": { "id": "Lqn-AYqHHu9L" }, "source": [ "# 4️⃣ Path Patching\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Understand the idea of path patching, and how it differs from activation patching\n", "> * Implement path patching from scratch (i.e. using hooks)\n", "> * Replicate several of the results in the [IOI paper](https://arxiv.org/abs/2211.00593)" ] }, { "cell_type": "markdown", "metadata": { "id": "rxq5rZiHHu9L" }, "source": [ "This section will be a lot less conceptual and exploratory than the last two sections, and a lot more technical and rigorous. You'll learn what path patching is and how it works, and you'll use it to replicate many of the paper's results (as well as some other paper results not related to path patching)." ] }, { "cell_type": "markdown", "metadata": { "id": "lPxM7lOCHu9L" }, "source": [ "## Setup" ] }, { "cell_type": "markdown", "metadata": { "id": "texnWG5YHu9L" }, "source": [ "Here, we'll be more closely following the setup that the paper's authors used, rather than the rough-and-ready exploration we used in the first few sections. To be clear, a lot of the rigour that we'll be using in the setup here isn't necessary if you're just starting to investigate a model's circuit. This rigour is necessary if you're publishing a paper, but it can take a lot of time and effort!" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Iarj_J_iHu9L" }, "outputs": [], "source": [ "from part41_indirect_object_identification.ioi_dataset import NAMES, IOIDataset" ] }, { "cell_type": "markdown", "metadata": { "id": "VloeXl2BHu9L" }, "source": [ "The dataset we'll be using is an instance of `IOIDataset`, which is generated by randomly choosing names from the `NAMES` list (as well as sentence templates and objects from different lists). You can look at the `ioi_dataset.py` file to see details of how this is done.\n", "\n", "(Note - you can reduce `N` if you're getting memory errors from running this code. If you're still getting memory errors from `N = 10` then you're recommended to switch to Colab, or to use a virtual machine e.g. 
via Lambda Labs.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "G6fEE--sHu9L" }, "outputs": [], "source": [ "N = 25\n", "ioi_dataset = IOIDataset(\n", " prompt_type=\"mixed\",\n", " N=N,\n", " tokenizer=model.tokenizer,\n", " prepend_bos=False,\n", " seed=1,\n", " device=str(device),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "UwzM-HnNHu9L" }, "source": [ "This dataset has a few useful attributes & methods. Here are the main ones you should be aware of for these exercises:\n", "\n", "* `toks` is a tensor of shape `(batch_size, max_seq_len)` containing the token IDs (i.e. this is what you pass to your model)\n", "* `s_tokenIDs` and `io_tokenIDs` are lists containing the token IDs for the subjects and objects\n", "* `sentences` is a list containing the sentences (as strings)\n", "* `word_idx` is a dictionary mapping word types (e.g. `\"S1\"`, `\"S2\"`, `\"IO\"` or `\"end\"`) to tensors containing the positions of those words for each sequence in the dataset.\n", " * This is particularly handy for indexing, since the positions of the subject, indirect object, and end tokens are no longer the same in every sentence like they were in previous sections." ] }, { "cell_type": "markdown", "metadata": { "id": "3-XceXYsHu9L" }, "source": [ "Firstly, what dataset should we use for patching? In the previous section we just flipped the subject and indirect object tokens around, which meant the direction of the signal was flipped around. However, what we'll be doing here is a bit more principled - rather than flipping the IOI signal, we'll be erasing it. We do this by constructing a new dataset from `ioi_dataset` which replaces every name with a different random name. This way, the sentence structure stays the same, but all information related to the actual indirect object identification task (i.e. the identities and positions of repeated names) has been erased.\n", "\n", "For instance, given the sentence `\"When John and Mary went to the shops, John gave the bag to Mary\"`, the corresponding sentence in the ABC dataset might be `\"When Edward and Laura went to the shops, Adam gave the bag to Mary\"`. We would expect the residual stream for the latter prompt to carry no token or positional information which could help it solve the IOI task (i.e. favouring `Mary` over `John`, or favouring the 2nd token over the 4th token).\n", "\n", "We define this dataset below. Note the syntax of the `gen_flipped_prompts` method - the letters tell us how to replace the names in the sequence. For instance, `ABB->XYZ` tells us to take sentences of the form `\"When Mary and John went to the store, John gave a drink to Mary\"` with `\"When [X] and [Y] went to the store, [Z] gave a drink to Mary\"` for 3 independent randomly chosen names `[X]`, `[Y]` and `[Z]`. We'll use this function more in the bonus section, when we're trying to disentangle positional and token signals (since we can also do fun things like `ABB->BAB` to swap the first two names, etc)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5IEjdR7RHu9L" }, "outputs": [], "source": [ "abc_dataset = ioi_dataset.gen_flipped_prompts(\"ABB->XYZ, BAB->XYZ\")" ] }, { "cell_type": "markdown", "metadata": { "id": "yfrolx0WHu9L" }, "source": [ "Let's take a look at this dataset. We'll define a helper function `make_table`, which prints out tables after being fed columns rather than rows (don't worry about the syntax, it's not important)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xAH7BYdZHu9L" }, "outputs": [], "source": [ "def format_prompt(sentence: str) -> str:\n", " \"\"\"Format a prompt by underlining names (for rich print)\"\"\"\n", " return re.sub(\"(\" + \"|\".join(NAMES) + \")\", lambda x: f\"[u bold dark_orange]{x.group(0)}[/]\", sentence) + \"\\n\"\n", "\n", "\n", "def make_table(cols, colnames, title=\"\", n_rows=5, decimals=4):\n", " \"\"\"Makes and displays a table, from cols rather than rows (using rich print)\"\"\"\n", " table = Table(*colnames, title=title)\n", " rows = list(zip(*cols))\n", " f = lambda x: x if isinstance(x, str) else f\"{x:.{decimals}f}\"\n", " for row in rows[:n_rows]:\n", " table.add_row(*list(map(f, row)))\n", " rprint(table)\n", "\n", "\n", "make_table(\n", " colnames=[\"IOI prompt\", \"IOI subj\", \"IOI indirect obj\", \"ABC prompt\"],\n", " cols=[\n", " map(format_prompt, ioi_dataset.sentences),\n", " model.to_string(ioi_dataset.s_tokenIDs).split(),\n", " model.to_string(ioi_dataset.io_tokenIDs).split(),\n", " map(format_prompt, abc_dataset.sentences),\n", " ],\n", " title=\"Sentences from IOI vs ABC distribution\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "Vs92NppuHu9L" }, "source": [ "Next, we'll define functions similar to the ones from previous sections. We've just given you these, rather than making you repeat the exercise of writing them (although you should compare these functions to the ones you wrote earlier, and make sure you understand how they work).\n", "\n", "We'll call these functions something slightly different, so as not to pollute namespace." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8BKvdsOdHu9L" }, "outputs": [], "source": [ "def logits_to_ave_logit_diff_2(\n", " logits: Float[Tensor, \"batch seq d_vocab\"], ioi_dataset: IOIDataset = ioi_dataset, per_prompt=False\n", ") -> Float[Tensor, \"*batch\"]:\n", " \"\"\"\n", " Returns logit difference between the correct and incorrect answer.\n", "\n", " If per_prompt=True, return the array of differences rather than the average.\n", " \"\"\"\n", " # Only the final logits are relevant for the answer\n", " # Get the logits corresponding to the indirect object / subject tokens respectively\n", " io_logits: Float[Tensor, \"batch\"] = logits[\n", " range(logits.size(0)), ioi_dataset.word_idx[\"end\"], ioi_dataset.io_tokenIDs\n", " ]\n", " s_logits: Float[Tensor, \"batch\"] = logits[\n", " range(logits.size(0)), ioi_dataset.word_idx[\"end\"], ioi_dataset.s_tokenIDs\n", " ]\n", " # Find logit difference\n", " answer_logit_diff = io_logits - s_logits\n", " return answer_logit_diff if per_prompt else answer_logit_diff.mean()\n", "\n", "\n", "model.reset_hooks(including_permanent=True)\n", "\n", "ioi_logits_original, ioi_cache = model.run_with_cache(ioi_dataset.toks)\n", "abc_logits_original, abc_cache = model.run_with_cache(abc_dataset.toks)\n", "\n", "ioi_per_prompt_diff = logits_to_ave_logit_diff_2(ioi_logits_original, per_prompt=True)\n", "abc_per_prompt_diff = logits_to_ave_logit_diff_2(abc_logits_original, per_prompt=True)\n", "\n", "ioi_average_logit_diff = logits_to_ave_logit_diff_2(ioi_logits_original).item()\n", "abc_average_logit_diff = logits_to_ave_logit_diff_2(abc_logits_original).item()\n", "\n", "print(f\"Average logit diff (IOI dataset): {ioi_average_logit_diff:.4f}\")\n", "print(f\"Average logit diff (ABC dataset): {abc_average_logit_diff:.4f}\")\n", "\n", "make_table(\n", " colnames=[\"IOI prompt\", \"IOI logit diff\", \"ABC 
prompt\", \"ABC logit diff\"],\n", " cols=[\n", " map(format_prompt, ioi_dataset.sentences),\n", " ioi_per_prompt_diff,\n", " map(format_prompt, abc_dataset.sentences),\n", " abc_per_prompt_diff,\n", " ],\n", " title=\"Sentences from IOI vs ABC distribution\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "elzT3JFxHu9M" }, "source": [ "Note that we're always measuring performance ***with respect to the correct answers for the IOI dataset, not the ABC dataset***, because we want our ABC dataset to carry no information that helps with the IOI task (hence patching it in gives us signals which are totally uncorrelated with the correct answer). For instance, the model will obviously not complete sentences like `\"When Max and Victoria got a snack at the store, Clark decided to give it to\"` with the name `\"Tyler\"`." ] }, { "cell_type": "markdown", "metadata": { "id": "oA2a4p90Hu9M" }, "source": [ "Finally, let's define a new `ioi_metric` function which works for our new data.\n", "\n", "In order to match the paper's results, we'll use a different convention here. 0 means performance is the same as on the IOI dataset (i.e. hasn't been harmed in any way), and -1 means performance is the same as on the ABC dataset (i.e. the model has completely lost the ability to distinguish between the subject and indirect object).\n", "\n", "Again, we'll call this function something slightly different." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dZiUdSurHu9M" }, "outputs": [], "source": [ "def ioi_metric_2(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " clean_logit_diff: float = ioi_average_logit_diff,\n", " corrupted_logit_diff: float = abc_average_logit_diff,\n", " ioi_dataset: IOIDataset = ioi_dataset,\n", ") -> float:\n", " \"\"\"\n", " We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),\n", " and -1 when performance has been destroyed (i.e. is same as ABC dataset).\n", " \"\"\"\n", " patched_logit_diff = logits_to_ave_logit_diff_2(logits, ioi_dataset)\n", " return (patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff)\n", "\n", "\n", "print(f\"IOI metric (IOI dataset): {ioi_metric_2(ioi_logits_original):.4f}\")\n", "print(f\"IOI metric (ABC dataset): {ioi_metric_2(abc_logits_original):.4f}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "PmJPbCk3Hu9M" }, "source": [ "## What is path patching?" ] }, { "cell_type": "markdown", "metadata": { "id": "XmNXxZDIHu9M" }, "source": [ "In the previous section, we looked at activation patching, which answers questions like *what would happen if you took an attention head, and swapped the value it writes to the residual stream with the value it would have written under a different distribution, while keeping everything else the same?*. This proved to be a good way to examine the role of individual components like attention heads, and it allowed us to perform some more subtle analysis like patching keys / queries / values in turn to figure out which of them were more important for which heads.\n", "\n", "However, when we're studying a circuit, rather than just swapping out an entire attention head, we might want to ask more nuanced questions like *what would happen if the direct input from attention head $A$ to head $B$ (where $B$ comes after $A$) was swapped out with the value it would have been under a different distribution, while keeping everything else the same?*. 
Rather than answering the general question of how important attention heads are, this answers the more specific question of how important the circuit formed by connecting up these two attention heads is. Path patching is designed to answer questions like these." ] }, { "cell_type": "markdown", "metadata": { "id": "ygwJG70mHu9M" }, "source": [ "The following diagrams might help explain the difference between activation and path patching in transformers. Recall that activation patching looked like:\n", "\n", "\n", "\n", "where the black and green distributions are our clean and corrupted datasets respectively (so this would be `ioi_dataset` and `abc_dataset`). In contrast, path patching involves replacing **edges** rather than **nodes**. In the diagram below, we're replacing the edge $D \to G$ with what it would be on the corrupted distribution. So in our patched run, $G$ is calculated just like it would be on the clean distribution, but as if the **direct** input from $D$ had come from the corrupted distribution instead.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "R2ZnZdJxHu9M" }, "source": [ "Unfortunately, for a transformer, this is easier to describe than to actually implement. This is because the \"nodes\" are attention heads, and the \"edges\" are all tangled together in the residual stream (that is to say, it's not clear how one could change the value of one edge without affecting every path that includes that edge). The solution is to use the 3-step algorithm shown in the diagram below (which reads from right to left).\n", "\n", "Terminology note - we call head $D$ the **sender node**, and head $G$ the **receiver node**. Also, by \"freezing\" nodes, we mean patching them with the same values they would have had anyway, so that they aren't recomputed from the patched inputs. For instance, if we didn't freeze head $H$ in step 2 below, it would have a different value because it would be affected by the corrupted value of head $D$.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "okmuKGx2Hu9M" }, "source": [ "Let's make this concrete, and take a simple 3-layer transformer with 2 heads per layer. Let's perform path patching on the edge from head `0.0` to `2.0` (terminology note: `0.0` is the **sender**, and `2.0` is the **receiver**). Note that here, we're considering \"direct paths\" as anything that doesn't go through another attention head (so it can go through any combination of MLPs). Intuitively, the nodes (attention heads) are the only things that can move information around in the model, and this is the thing we want to study. In contrast, MLPs just perform information processing, and they're not as interesting for this task.\n", "\n", "Our 3-step process looks like the diagram below (remember green is corrupted, grey is clean).\n", "\n", "\n", "\n", "(Note - in this diagram, the uncoloured nodes indicate we aren't doing any patching; we're just allowing them to be computed from the values of the nodes upstream of them.)" ] }, { "cell_type": "markdown", "metadata": { "id": "fDukkj77Hu9M" }, "source": [ "Why does this work? If you stare at the middle picture above for long enough, you'll realise that the contribution from every non-direct path from `0.0` $\to$ `2.0` is the same as it would be on the clean distribution, while all the direct paths' contributions are the same as they would be on the corrupted distribution.\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "_e6dpZFLHu9M" }, "source": [ "### Why MLPs?"
] }, { "cell_type": "markdown", "metadata": { "id": "lKwMUW0eHu9M" }, "source": [ "You might be wondering why we're including MLPs as part of our direct path. The short answer is that this is what the IOI paper does, and we're trying to replicate it! The slightly longer answer is that both this method and a method which doesn't count MLPs as the direct path are justifiable.\n", "\n", "To take one example, suppose the output of head `0.0` is being used directly by head `2.0`, but one of the MLPs is acting as a mediator. To oversimplify, we might imagine that `0.0` writes the vector $v$ into the residual stream, some neuron detects $v$ and writes $w$ to the residual stream, and `2.0` detects $w$. If we didn't count MLPs as a direct path then we wouldn't catch this causal relationship. The drawback is that things get a bit messier, because now we're essentially passing a \"fake input\" into our MLPs, and it's dangerous to assume that any operation as clean as the one previously described (with vectors $v$, $w$) would still happen under these new circumstances.\n", "\n", "Also, having MLPs as part of the direct path doesn't help us understand what role the MLPs play in the circuit, all it does is tell us that some of them are important! Luckily, in the IOI circuit, MLPs aren't important (except for MLP0), and so doing both these forms of path patching get pretty similar results. As an optional exercise, you can reproduce the results from the following few sections using this different form of path patching. It's actually algorithmically easier to implement, because we only need one forward pass rather than two. Can you see why?\n", "\n", "
\n", "Answer\n", "\n", "Because the MLPs were part of the direct paths between sender and receiver in the previous version of the algorithm, we had to do a forward pass to find the value we'd be patching into the receivers. But if MLPs aren't part of the direct path, then we can directly compute what to patch into the receiver nodes:\n", "\n", "```\n", "orig_receiver_input <- orig_receiver_input + (new_sender_output - old_sender_output)\n", "```\n", "\n", "Diagram with direct paths not including MLPs:\n", "\n", "\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "7qZMKz2RHu9M" }, "source": [ "## Path Patching: Name Mover Heads" ] }, { "cell_type": "markdown", "metadata": { "id": "BKPM65tFHu9M" }, "source": [ "We'll start with a simple type of path patching - with just one receiver node, which is the final value of the residual stream. We've only discussed receiver nodes being other attention heads so far, but the same priciples hold for any choice of receiver nodes.\n", "\n", "
\n", "Question - can you explain the difference between path patching from an attention head to the residual stream, and activation patching on that attention head?\n", "\n", "Activation patching changes the value of that head, and all subsequent layers which depend on that head.\n", "\n", "Path patching will answer the question \"what if the value written by the head directly to the residual stream was the same as in $x_{new}$, but every non-direct path from this head to the residual stream (i.e. paths going through other heads) the value was the same as it would have been under $x_{orig}$?\n", "
\n", "\n", "This patching is described at the start of section 3.1 in [the paper](https://arxiv.org/pdf/2211.00593.pdf) (page 5). The 3-step process will look like:\n", "\n", "1. Run the model on clean and corrupted input. Cache the head outputs.\n", "2. Run the model on clean input, with the sender head **patched** from the corrupted input, and every other head **frozen** to their values on the clean input. Cache the final value of the residual stream (i.e. `resid_post` in the final layer).\n", "3. Normally we would re-run the model on the clean input and patch in the cached value of the final residual stream, but in this case we don't need to because we can just unembed the final value of the residual stream directly without having to run another forward pass.\n", "\n", "Here is an illustration for a 2-layer transformer:\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "yqyFleazHu9M" }, "source": [ "### Exercise - implement path patching to the final residual stream value\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴🔴\n", "> Importance: 🔵🔵🔵🔵⚪\n", ">\n", "> You should spend up to 30-45 minutes on this exercise.\n", ">\n", "> Path patching is a very challenging algorithm with many different steps.\n", "> ```\n", "\n", "You should implement path patching from heads to the residual stream, as described above (and in the paper).\n", "\n", "This exercise is expected to be challenging, with several moving parts. We've purposefully left it very open-ended, without even giving you a docstring for the function you'll be writing.\n", "\n", "Here are a few hints / tips for how to proceed:\n", "\n", "* Split your function up into 3 parts (one for each of the steps above), and write each section one at a time.\n", "* You'll need a new hook function: one which performs freezing / patching for step 2 of the algorithm.\n", "* You can reuse a lot of code from your activation patching function.\n", "* When calling `model.run_with_cache`, you can use the keyword argument `names_filter`, which is a function from name to boolean. If you use this argument, your model will only cache activtions with a name which passes this filter (e.g. you can use it like `names_filter = lambda name: name.endswith(\"q\")` to only cache query vectors).\n", "\n", "You can also look at the dropdowns to get more hints and guidance (e.g. if you want to start from a function docstring).\n", "\n", "You'll know you've succeeded if you can plot the results, and replicate Figure 3(b) from [the paper](https://arxiv.org/pdf/2211.00593.pdf) (at the top of page 6).\n", "\n", "**Note - if you use `model.add_hook` then `model.run_with_cache`, you might have to pass the argument `level=1` to the `add_hook` method. I don't know why the function sometimes fails unless you do this (this bug only started appearing after the exercises were written). I've not had time to track this down, but extra credit to anyone who can (-:**" ] }, { "cell_type": "markdown", "metadata": { "id": "V7bAkQqCHu9M" }, "source": [ "
\n", "Click here to get a docstring for the main function.\n", "\n", "```python\n", "def get_path_patch_head_to_final_resid_post(\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = abc_cache,\n", " orig_cache: ActivationCache | None = ioi_cache,\n", ") -> Float[Tensor, \"layer head\"]:\n", " '''\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = final value of residual stream\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " '''\n", " pass\n", "```\n", "
\n", "\n", "
\n", "Click here to get a docstring for the main function, plus some annotations and function structure.\n", "\n", "```python\n", "def get_path_patch_head_to_final_resid_post(\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = abc_cache,\n", " orig_cache: ActivationCache | None = ioi_cache,\n", ") -> Float[Tensor, \"layer head\"]:\n", " '''\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = final value of residual stream\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " '''\n", " model.reset_hooks()\n", " results = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)\n", "\n", " # ========== Step 1 ==========\n", " # Gather activations on x_orig and x_new\n", "\n", " # YOUR CODE HERE\n", "\n", "\n", " # Using itertools to loop gives us a smoother progress bar (using nested for loops is also fine)\n", " for (sender_layer, sender_head) in tqdm_notebook(list(itertools.product(\n", " range(model.cfg.n_layers),\n", " range(model.cfg.n_heads)\n", " ))):\n", " pass\n", "\n", " # ========== Step 2 ==========\n", " # Run on x_orig, with sender head patched from x_new, every other head frozen\n", "\n", " # YOUR CODE HERE\n", "\n", "\n", " # ========== Step 3 ==========\n", " # Unembed the final residual stream value, to get our patched logits\n", "\n", " # YOUR CODE HERE\n", "\n", "\n", " # Save the results\n", " results[sender_layer, sender_head] = patching_metric(patched_logits)\n", "\n", "\n", " return results\n", "```\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7ptPt_nDHu9M" }, "outputs": [], "source": [ "def patch_or_freeze_head_vectors(\n", " orig_head_vector: Float[Tensor, \"batch pos head_index d_head\"],\n", " hook: HookPoint,\n", " new_cache: ActivationCache,\n", " orig_cache: ActivationCache,\n", " head_to_patch: tuple[int, int],\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them to their values in\n", " orig_cache), except for head_to_patch (if it's in this layer) which we patch with the value from new_cache.\n", "\n", " head_to_patch: tuple of (layer, head)\n", " \"\"\"\n", " # Setting using ..., otherwise changing orig_head_vector will edit cache value too\n", " orig_head_vector[...] = orig_cache[hook.name][...]\n", " if head_to_patch[0] == hook.layer():\n", " orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]\n", " return orig_head_vector\n", "\n", "\n", "def get_path_patch_head_to_final_resid_post(\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = abc_cache,\n", " orig_cache: ActivationCache | None = ioi_cache,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = final value of residual stream\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(model, ioi_metric_2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "WIDPpKybHu9M" }, "outputs": [], "source": [ "imshow(\n", " 100 * path_patch_head_to_final_resid_post,\n", " title=\"Direct effect on logit difference\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff. variation\"},\n", " coloraxis=dict(colorbar_ticksuffix=\"%\"),\n", " width=600,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "-EN22BWjHu9M" }, "source": [ "
\n", "Help - all the values in my heatmap are the same.\n", "\n", "There could be a few possible reasons for this. A common one is that you're changing an actual tensor, rather than just changing its values - this means when one tensor changes, the other one does too. For instance, if you do something like:\n", "\n", "```python\n", "x = t.zeros(3)\n", "y = x\n", "x[0] = 1\n", "print(y)\n", "```\n", "\n", "then `y` will also be `[1, 0, 0]`. To avoid this, you can use the `...` syntax, which means \"set all values in this tensor to the values in this other tensor\". For instance, if you do:\n", "\n", "```python\n", "x = t.zeros(3)\n", "y = t.zeros(3)\n", "x[...] = y\n", "x[0] = 1\n", "print(y)\n", "```\n", "\n", "then `y` will still be `[0, 0, 0]`.\n", "\n", "Using `x[:] = y` will also work.\n", "\n", "---\n", "\n", "Another possible explanation would be passing in the wrong input values / cache at some point in the algorithm, or freezing to the wrong values. Remember that in the diagram, grey represents original values (clean) and blue represents new values (corrupted), so e.g. in step 2 we want to run the model on `orig_dataset` (= IOI dataset) and we also want to freeze all non-sender heads to their values in `orig_cache`.\n", "\n", "---\n", "\n", "Lastly, make sure you're not freezing your heads in a way that doesn't override the sender patching! If more than one hook function is added to a hook point, they're executed in the order they were added (with the last one possibly overriding the previous ones).\n", "\n", "
\n", "\n", "\n", "
Solution\n", "\n", "```python\n", "def patch_or_freeze_head_vectors(\n", " orig_head_vector: Float[Tensor, \"batch pos head_index d_head\"],\n", " hook: HookPoint,\n", " new_cache: ActivationCache,\n", " orig_cache: ActivationCache,\n", " head_to_patch: tuple[int, int],\n", ") -> Float[Tensor, \"batch pos head_index d_head\"]:\n", " \"\"\"\n", " This helps implement step 2 of path patching. We freeze all head outputs (i.e. set them to their values in\n", " orig_cache), except for head_to_patch (if it's in this layer) which we patch with the value from new_cache.\n", "\n", " head_to_patch: tuple of (layer, head)\n", " \"\"\"\n", " # Setting using ..., otherwise changing orig_head_vector will edit cache value too\n", " orig_head_vector[...] = orig_cache[hook.name][...]\n", " if head_to_patch[0] == hook.layer():\n", " orig_head_vector[:, :, head_to_patch[1]] = new_cache[hook.name][:, :, head_to_patch[1]]\n", " return orig_head_vector\n", "\n", "\n", "def get_path_patch_head_to_final_resid_post(\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = abc_cache,\n", " orig_cache: ActivationCache | None = ioi_cache,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = final value of residual stream\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " \"\"\"\n", " model.reset_hooks()\n", " results = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)\n", "\n", " resid_post_hook_name = utils.get_act_name(\"resid_post\", model.cfg.n_layers - 1)\n", " resid_post_name_filter = lambda name: name == resid_post_hook_name\n", "\n", " # ========== Step 1 ==========\n", " # Gather activations on x_orig and x_new\n", "\n", " # Note the use of names_filter for the run_with_cache function. 
Using it means we\n", " # only cache the things we need (in this case, just attn head outputs).\n", " z_name_filter = lambda name: name.endswith(\"z\")\n", " if new_cache is None:\n", " _, new_cache = model.run_with_cache(new_dataset.toks, names_filter=z_name_filter, return_type=None)\n", " if orig_cache is None:\n", " _, orig_cache = model.run_with_cache(orig_dataset.toks, names_filter=z_name_filter, return_type=None)\n", "\n", " # Looping over every possible sender head (the receiver is always the final resid_post)\n", " for sender_layer, sender_head in tqdm(list(product(range(model.cfg.n_layers), range(model.cfg.n_heads)))):\n", " # ========== Step 2 ==========\n", " # Run on x_orig, with sender head patched from x_new, every other head frozen\n", "\n", " hook_fn = partial(\n", " patch_or_freeze_head_vectors,\n", " new_cache=new_cache,\n", " orig_cache=orig_cache,\n", " head_to_patch=(sender_layer, sender_head),\n", " )\n", " model.add_hook(z_name_filter, hook_fn)\n", "\n", " _, patched_cache = model.run_with_cache(\n", " orig_dataset.toks, names_filter=resid_post_name_filter, return_type=None\n", " )\n", " # if (sender_layer, sender_head) == (9, 9):\n", " # return patched_cache\n", " assert set(patched_cache.keys()) == {resid_post_hook_name}\n", "\n", " # ========== Step 3 ==========\n", " # Unembed the final residual stream value, to get our patched logits\n", "\n", " patched_logits = model.unembed(model.ln_final(patched_cache[resid_post_hook_name]))\n", "\n", " # Save the results\n", " results[sender_layer, sender_head] = patching_metric(patched_logits)\n", "\n", " return results\n", "\n", "\n", "path_patch_head_to_final_resid_post = get_path_patch_head_to_final_resid_post(model, ioi_metric_2)\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "flvXwYMKHu9M" }, "source": [ "What is the interpretation of this plot? How does it compare to the equivalent plot we got from activation patching? (Remember that our metric is defined in a different way, so we should expect a sign difference between the two results.)\n", "\n", "
\n", "Some thoughts\n", "\n", "This plot is actually almost identical to the one we got from activation patching (apart from the results being negated, because of the new metric).\n", "\n", "This makes sense; the only reason activation patching would do something different to path patching is if the heads writing in the `Mary - John` direction had their outputs used by a later head (because this would be accounted for in activation patching, whereas path patching isolates the direct effect on the residual stream only). Since attention heads' primary purpose is to move information around the model, it's reasonable to guess that this probably isn't happening.\n", "\n", "Don't worry though, in the next set of exercises we'll do some more interesting path patching, and we'll get some results which are meaningfully different from our activation patching results.\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "JbwRZFbzHu9M" }, "source": [ "## Path Patching: S-Inhibition Heads" ] }, { "cell_type": "markdown", "metadata": { "id": "Dlu7zyiaHu9M" }, "source": [ "In the first section on path patching, we performed a simple kind of patching - from the output of an attention head to the final value of the residual stream. Here we'll do something a bit more interesting, and patch from the output of one head to the input of a later head. The purpose of this is to examine exactly how two heads are composing, and what effect the composed heads have on the model's output.\n", "\n", "We got a hint of this in the previous section, where we patched the values of the S-inhibition heads and found that they were important. But this didn't tell us which inputs to these value vectors were important; we had to make educated guesses about this based on our analysis earlier parts of the model. In path patching, we can perform a more precise test to find which heads are important.\n", "\n", "The paper's results from path patching are shown in figure 5(b), on page 7." ] }, { "cell_type": "markdown", "metadata": { "id": "tB2ICW9LHu9M" }, "source": [ "### Exercise - implement path patching from head to head\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴⚪⚪\n", "> Importance: 🔵🔵🔵⚪⚪\n", ">\n", "> You should spend up to 20-25 minutes on this exercise.\n", ">\n", "> You'll need a new hook function, but copying code from the previous exercise should make this one easier.\n", "> ```\n", "\n", "You should fill in the function `get_path_patch_head_to_head` below. It takes as arguments a list of receiver nodes (as well as the type of input - keys, queries, or values), and returns a tensor of shape\\* `(layer, head)` where each element is the result of running the patching metric on the output of the model, after applying the 3-step path patching algorithm from one of the model's heads to all the receiver heads. You should be able to replicate the paper's results (figure 5(b)).\n", "\n", "\\**Actually, you don't need to return all layers, because the causal effect from any sender head which is on the same or a later layer than the last of your receiver heads will necessarily be zero.*\n", "\n", "If you want a bit more guidance, you can use the dropdown below to see the ways in which this function should be different from your first path patching function (in most ways these functions will be similar, so you can start by copying that function).\n", "\n", "
\n", "Differences from first path patching function\n", "\n", "Step 1 is identical in both - gather all the observations.\n", "\n", "Step 2 is very similar. The only difference is that you'll be caching a different set of activations (your receiver heads).\n", "\n", "In section 3, since your receiver nodes are in the middle of the model rather than at the very end, you will have to run the model again with these nodes patched in rather than just calculating the logit output directly from the patched values of the final residual stream. To do this, you'll have to write a new hook function to patch in the inputs to an attention head (if you haven't done this already).\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GX6_NWnzHu9M" }, "outputs": [], "source": [ "def patch_head_input(\n", " orig_activation: Float[Tensor, \"batch pos head_idx d_head\"],\n", " hook: HookPoint,\n", " patched_cache: ActivationCache,\n", " head_list: list[tuple[int, int]],\n", ") -> Float[Tensor, \"batch pos head_idx d_head\"]:\n", " \"\"\"\n", " Function which can patch any combination of heads in layers,\n", " according to the heads in head_list.\n", " \"\"\"\n", " heads_to_patch = [head for layer, head in head_list if layer == hook.layer()]\n", " orig_activation[:, :, heads_to_patch] = patched_cache[hook.name][:, :, heads_to_patch]\n", " return orig_activation\n", "\n", "\n", "def get_path_patch_head_to_heads(\n", " receiver_heads: list[tuple[int, int]],\n", " receiver_input: str,\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = None,\n", " orig_cache: ActivationCache | None = None,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = input to a later head (or set of heads)\n", "\n", " The receiver node is specified by receiver_heads and receiver_input, for example if receiver_input = \"v\" and\n", " receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)], we're doing path patching from each head to the value inputs of\n", " the S-inhibition heads.\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " \"\"\"\n", " model.reset_hooks()\n", "\n", " raise NotImplementedError()\n", "\n", "\n", "model.reset_hooks()\n", "\n", "s_inhibition_value_path_patching_results = get_path_patch_head_to_heads(\n", " receiver_heads=[(8, 6), (8, 10), (7, 9), (7, 3)], receiver_input=\"v\", model=model, patching_metric=ioi_metric_2\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fHhcnPLLHu9M" }, "outputs": [], "source": [ "imshow(\n", " 100 * s_inhibition_value_path_patching_results,\n", " title=\"Direct effect on S-Inhibition Heads' values\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff.
variation\"},\n", " width=600,\n", " coloraxis=dict(colorbar_ticksuffix=\"%\"),\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "1ubZ1eifHu9M" }, "source": [ "
\n", "Question - what is the interpretation of this plot? \n", "\n", "This plot confirms our earlier observations, that the S-inhibition heads' value vectors are the ones which matter. But it does more, by confirming our hypothesis that the S-inhibition heads' value vectors are supplied to them primarily by the outputs of heads `0.1`, `3.0`, `5.5` and `6.9` (which are the heads found by the paper to be the two most important duplicate token heads and two most important induction heads respectively).\n", "
\n", "\n", "\n", "
Solution\n", "\n", "```python\n", "def patch_head_input(\n", " orig_activation: Float[Tensor, \"batch pos head_idx d_head\"],\n", " hook: HookPoint,\n", " patched_cache: ActivationCache,\n", " head_list: list[tuple[int, int]],\n", ") -> Float[Tensor, \"batch pos head_idx d_head\"]:\n", " \"\"\"\n", " Function which can patch any combination of heads in layers,\n", " according to the heads in head_list.\n", " \"\"\"\n", " heads_to_patch = [head for layer, head in head_list if layer == hook.layer()]\n", " orig_activation[:, :, heads_to_patch] = patched_cache[hook.name][:, :, heads_to_patch]\n", " return orig_activation\n", "\n", "\n", "def get_path_patch_head_to_heads(\n", " receiver_heads: list[tuple[int, int]],\n", " receiver_input: str,\n", " model: HookedTransformer,\n", " patching_metric: Callable,\n", " new_dataset: IOIDataset = abc_dataset,\n", " orig_dataset: IOIDataset = ioi_dataset,\n", " new_cache: ActivationCache | None = None,\n", " orig_cache: ActivationCache | None = None,\n", ") -> Float[Tensor, \"layer head\"]:\n", " \"\"\"\n", " Performs path patching (see algorithm in appendix B of IOI paper), with:\n", "\n", " sender head = (each head, looped through, one at a time)\n", " receiver node = input to a later head (or set of heads)\n", "\n", " The receiver node is specified by receiver_heads and receiver_input, for example if receiver_input = \"v\" and\n", " receiver_heads = [(8, 6), (8, 10), (7, 9), (7, 3)], we're doing path patching from each head to the value inputs of\n", " the S-inhibition heads.\n", "\n", " Returns:\n", " tensor of metric values for every possible sender head\n", " \"\"\"\n", " model.reset_hooks()\n", "\n", " assert receiver_input in (\"k\", \"q\", \"v\")\n", " receiver_layers = set(next(zip(*receiver_heads)))\n", " receiver_hook_names = [utils.get_act_name(receiver_input, layer) for layer in receiver_layers]\n", " receiver_hook_names_filter = lambda name: name in receiver_hook_names\n", "\n", " results = t.zeros(max(receiver_layers), model.cfg.n_heads, device=device, dtype=t.float32)\n", "\n", " # ========== Step 1 ==========\n", " # Gather activations on x_orig and x_new\n", "\n", " # Note the use of names_filter for the run_with_cache function. Using it means we\n", " # only cache the things we need (in this case, just attn head outputs).\n", " z_name_filter = lambda name: name.endswith(\"z\")\n", " if new_cache is None:\n", " _, new_cache = model.run_with_cache(new_dataset.toks, names_filter=z_name_filter, return_type=None)\n", " if orig_cache is None:\n", " _, orig_cache = model.run_with_cache(orig_dataset.toks, names_filter=z_name_filter, return_type=None)\n", "\n", " # Note, the sender layer will always be before the final receiver layer, otherwise there will\n", " # be no causal effect from sender -> receiver. 
So we only need to loop this far.\n", " for sender_layer, sender_head in tqdm(list(product(range(max(receiver_layers)), range(model.cfg.n_heads)))):\n", " # ========== Step 2 ==========\n", " # Run on x_orig, with sender head patched from x_new, every other head frozen\n", "\n", " hook_fn = partial(\n", " patch_or_freeze_head_vectors,\n", " new_cache=new_cache,\n", " orig_cache=orig_cache,\n", " head_to_patch=(sender_layer, sender_head),\n", " )\n", " model.add_hook(z_name_filter, hook_fn, level=1)\n", "\n", " _, patched_cache = model.run_with_cache(\n", " orig_dataset.toks, names_filter=receiver_hook_names_filter, return_type=None\n", " )\n", " # model.reset_hooks(including_permanent=True)\n", " assert set(patched_cache.keys()) == set(receiver_hook_names)\n", "\n", " # ========== Step 3 ==========\n", " # Run on x_orig, patching in the receiver node(s) from the previously cached value\n", "\n", " hook_fn = partial(\n", " patch_head_input,\n", " patched_cache=patched_cache,\n", " head_list=receiver_heads,\n", " )\n", " patched_logits = model.run_with_hooks(\n", " orig_dataset.toks, fwd_hooks=[(receiver_hook_names_filter, hook_fn)], return_type=\"logits\"\n", " )\n", "\n", " # Save the results\n", " results[sender_layer, sender_head] = patching_metric(patched_logits)\n", "\n", " return results\n", "\n", "\n", "model.reset_hooks()\n", "\n", "s_inhibition_value_path_patching_results = get_path_patch_head_to_heads(\n", " receiver_heads=[(8, 6), (8, 10), (7, 9), (7, 3)], receiver_input=\"v\", model=model, patching_metric=ioi_metric_2\n", ")\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "Hv-RBTmLHu9N" }, "source": [ "# 5️⃣ Full Replication: Minimial Circuits and more\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Replicate most of the other results from the [IOI paper](https://arxiv.org/abs/2211.00593)\n", "> * Practice more open-ended, less guided coding" ] }, { "cell_type": "markdown", "metadata": { "id": "JOgoOnd0Hu9N" }, "source": [ "This section will be a lot more open-ended and challenging. You'll be given somewhat less guidance in the exercises." ] }, { "cell_type": "markdown", "metadata": { "id": "EY52wpXTHu9N" }, "source": [ "## Copying & writing direction results" ] }, { "cell_type": "markdown", "metadata": { "id": "clPJXWfXHu9N" }, "source": [ "We'll start this section by replicating the paper's analysis of the **name mover heads** and **negative name mover heads**. Our previous analysis should have pretty much convinced us that these heads are copying / negatively copying our indirect object token, but the results here show this with a bit more rigour." ] }, { "cell_type": "markdown", "metadata": { "id": "TkzMyQiwHu9N" }, "source": [ "### Exercise - replicate writing direction results\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴⚪\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> You should spend up to 20-25 minutes on this exercise.\n", ">\n", "> These exercises are much more challenging than they are conceptually important.\n", "> ```\n", "\n", "Let's look at figure 3(c) from the paper. This plots the output of the strongest name mover and negative name mover heads against the attention probabilities for `END` attending to `IO` or `S` (color-coded).\n", "\n", "Some clarifications:\n", "* \"Projection\" here is being used synonymously with \"dot product\".\n", "* We're projecting onto the name embedding, i.e. the embedding vector for the token which is being paid attention to. This is not the same as the logit diff (which we got by projecting the heads' output onto the difference between the unembedding vectors for `IO` and `S`).\n", " * We're doing this because the question we're trying to answer is *\"does the attention head copy (or anti-copy) the names which it pays attention to?\"*" ] }, { "cell_type": "markdown", "metadata": { "id": "QRRRxXAAHu9N" }, "source": [ "You should write code to replicate the paper's results in the cells below. Given four 1D tensors storing the results for a particular head (i.e. the projections and attention probabilities, for the `IO` and `S` tokens respectively), we've given you code to generate a plot which looks like the one in the paper. Again, you'll know that your code has worked if you can get results that resemble those in the paper." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "m4bhfzT0Hu9N" }, "outputs": [], "source": [ "def scatter_embedding_vs_attn(\n", " attn_from_end_to_io: Float[Tensor, \"batch\"],\n", " attn_from_end_to_s: Float[Tensor, \"batch\"],\n", " projection_in_io_dir: Float[Tensor, \"batch\"],\n", " projection_in_s_dir: Float[Tensor, \"batch\"],\n", " layer: int,\n", " head: int,\n", "):\n", " scatter(\n", " x=t.concat([attn_from_end_to_io, attn_from_end_to_s], dim=0),\n", " y=t.concat([projection_in_io_dir, projection_in_s_dir], dim=0),\n", " color=[\"IO\"] * N + [\"S\"] * N,\n", " title=f\"Projection of the output of {layer}.{head} along the name
embedding vs attention probability on name\",\n", " title_x=0.5,\n", " labels={\"x\": \"Attn prob on name\", \"y\": \"Dot w Name Embed\", \"color\": \"Name type\"},\n", " color_discrete_sequence=[\"#72FF64\", \"#C9A5F7\"],\n", " width=650,\n", " )\n", "\n", "\n", "def calculate_and_show_scatter_embedding_vs_attn(\n", " layer: int,\n", " head: int,\n", " cache: ActivationCache = ioi_cache,\n", " dataset: IOIDataset = ioi_dataset,\n", ") -> None:\n", " \"\"\"\n", " Creates and plots a figure equivalent to 3(c) in the paper.\n", "\n", " This should involve computing the four 1D tensors:\n", " attn_from_end_to_io\n", " attn_from_end_to_s\n", " projection_in_io_dir\n", " projection_in_s_dir\n", " and then calling the scatter_embedding_vs_attn function.\n", " \"\"\"\n", " raise NotImplementedError()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "uobnxbNQHu9N" }, "outputs": [], "source": [ "calculate_and_show_scatter_embedding_vs_attn(9, 9) # name mover head 9.9\n", "\n", "calculate_and_show_scatter_embedding_vs_attn(11, 10) # negative name mover head 11.10" ] }, { "cell_type": "markdown", "metadata": { "id": "8M9jHR3YHu9N" }, "source": [ "
Solution\n", "\n", "```python\n", "def scatter_embedding_vs_attn(\n", " attn_from_end_to_io: Float[Tensor, \"batch\"],\n", " attn_from_end_to_s: Float[Tensor, \"batch\"],\n", " projection_in_io_dir: Float[Tensor, \"batch\"],\n", " projection_in_s_dir: Float[Tensor, \"batch\"],\n", " layer: int,\n", " head: int,\n", "):\n", " scatter(\n", " x=t.concat([attn_from_end_to_io, attn_from_end_to_s], dim=0),\n", " y=t.concat([projection_in_io_dir, projection_in_s_dir], dim=0),\n", " color=[\"IO\"] * N + [\"S\"] * N,\n", " title=f\"Projection of the output of {layer}.{head} along the name
embedding vs attention probability on name\",\n", " title_x=0.5,\n", " labels={\"x\": \"Attn prob on name\", \"y\": \"Dot w Name Embed\", \"color\": \"Name type\"},\n", " color_discrete_sequence=[\"#72FF64\", \"#C9A5F7\"],\n", " width=650,\n", " )\n", "\n", "\n", "def calculate_and_show_scatter_embedding_vs_attn(\n", " layer: int,\n", " head: int,\n", " cache: ActivationCache = ioi_cache,\n", " dataset: IOIDataset = ioi_dataset,\n", ") -> None:\n", " \"\"\"\n", " Creates and plots a figure equivalent to 3(c) in the paper.\n", "\n", " This should involve computing the four 1D tensors:\n", " attn_from_end_to_io\n", " attn_from_end_to_s\n", " projection_in_io_dir\n", " projection_in_s_dir\n", " and then calling the scatter_embedding_vs_attn function.\n", " \"\"\"\n", " # Get the value written to the residual stream at the end token by this head\n", " z = cache[utils.get_act_name(\"z\", layer)][:, :, head] # [batch seq d_head]\n", " N = z.size(0)\n", " output = z @ model.W_O[layer, head] # [batch seq d_model]\n", " output_on_end_token = output[t.arange(N), dataset.word_idx[\"end\"]] # [batch d_model]\n", "\n", " # Get the directions we'll be projecting onto\n", " io_unembedding = model.W_U.T[dataset.io_tokenIDs] # [batch d_model]\n", " s_unembedding = model.W_U.T[dataset.s_tokenIDs] # [batch d_model]\n", "\n", " # Get the value of projections, by multiplying and summing over the d_model dimension\n", " projection_in_io_dir = (output_on_end_token * io_unembedding).sum(-1) # [batch]\n", " projection_in_s_dir = (output_on_end_token * s_unembedding).sum(-1) # [batch]\n", "\n", " # Get attention probs, and index to get the probabilities from END -> IO / S\n", " attn_probs = cache[\"pattern\", layer][:, head] # [batch seqQ seqK]\n", " attn_from_end_to_io = attn_probs[t.arange(N), dataset.word_idx[\"end\"], dataset.word_idx[\"IO\"]] # [batch]\n", " attn_from_end_to_s = attn_probs[t.arange(N), dataset.word_idx[\"end\"], dataset.word_idx[\"S1\"]] # [batch]\n", "\n", " # Show scatter plot\n", " scatter_embedding_vs_attn(\n", " attn_from_end_to_io, attn_from_end_to_s, projection_in_io_dir, projection_in_s_dir, layer, head\n", " )\n", "\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "K0sWzv5wHu9N" }, "source": [ "### Exercise - replicate copying score results\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴🔴\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> You should spend up to 30-40 minutes on this exercise.\n", ">\n", "> These exercises are much more challenging than they are conceptually important.\n", "> ```\n", "\n", "Now let's do a different kind of test of the name mover heads' copying, by looking directly at the OV circuits.\n", "\n", "From page 6 of the paper:\n", "\n", "> To check that the Name Mover Heads copy names generally, we studied what values are written via the heads’ OV matrix. Specifically, we first obtained the state of the residual stream at the position of each name token after the first MLP layer. Then, we multiplied this by the OV matrix of a Name Mover Head (simulating what would happen if the head attended perfectly to that token), multiplied by the unembedding matrix, and applied the final layer norm to obtain logit probabilities. We compute the proportion of samples that contain the input name token in the top 5 logits (N = 1000) and call this the copy score. All three Name Mover Heads have a copy score above 95%, compared to less than 20% for an average head.\n", ">\n", "> Negative Name Mover Heads ... have a large negative copy score–the copy score calculated with the negative of the OV matrix (98% compared to 12% for an average head).\n", "\n", "Note the similarity between their method and how we studied copying in induction heads, during an earlier set of exercises. However, there are differences (e.g. we're only looking at whether the head copies names, rather than whether it copies tokens in general)." ] }, { "cell_type": "markdown", "metadata": { "id": "dyjgTvcfHu9N" }, "source": [ "You should replicate these results by completing the `get_copying_scores` function below.\n", "\n", "You could do this by indexing from the `ioi_cache`, but a much more principled alternative would be to embed all the names in the `NAMES` list and apply operations like MLPs, layernorms and OV matrices manually. This is what the solutions do.\n", "\n", "A few notes:\n", "\n", "* You can use `model.to_tokens` to convert the names to tokens. Remember to use `prepend_bos=False`, since you just want the tokens of names so you can embed them. Note that this function will treat the list of names as a batch of single-token inputs, which works fine for our purposes.\n", "* You can apply MLPs and layernorms as functions, by just indexing the model's blocks (e.g. use `model.blocks[i].mlp` or `model.blocks[j].ln1` as a function). Remember that `ln1` is the layernorm that comes before attention, and `ln2` comes before the MLP.\n", "* Remember that you need to apply MLP0 before you apply the OV matrix (which is why we omit the 0th layer in our scores). The reason for this is that ablating MLP0 has a strangely large effect in gpt2-small relative to ablating other MLPs, possibly because it's acting as an extended embedding (see [here](https://www.lesswrong.com/s/yivyHaCAmMJ3CqSyj/p/XNjRwEX9kxbpzWFWd#:~:text=GPT%2D2%20Small%E2%80%99s%20performance%20is%20ruined%20if%20you%20ablate%20MLP0) for an explanation).\n", "\n", "Also, you shouldn't expect to get exactly the same results as the paper (because some parts of this experiment have been set up very slightly different), although you probably shouldn't be off by more than 10%." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "SG2IZ5O9Hu9N" }, "outputs": [], "source": [ "def get_copying_scores(model: HookedTransformer, k: int = 5, names: list = NAMES) -> Float[Tensor, \"2 layer-1 head\"]:\n", " \"\"\"\n", " Gets copying scores (both positive and negative) as described in page 6 of the IOI paper, for every (layer, head)\n", " pair in the model.\n", "\n", " Returns these in a 3D tensor (the first dimension is for positive vs negative).\n", "\n", " Omits the 0th layer, because this is before MLP0 (which we're claiming acts as an extended embedding).\n", " \"\"\"\n", " raise NotImplementedError()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QzI502xjHu9N" }, "outputs": [], "source": [ "copying_results = get_copying_scores(model)\n", "\n", "imshow(\n", " copying_results,\n", " facet_col=0,\n", " facet_labels=[\"Positive copying scores\", \"Negative copying scores\"],\n", " title=\"Copying scores of attention heads' OV circuits\",\n", " width=800\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-pzXKpe6Hu9N" }, "outputs": [], "source": [ "heads = {\"name mover\": [(9, 9), (10, 0), (9, 6)], \"negative name mover\": [(10, 7), (11, 10)]}\n", "\n", "for i, name in enumerate([\"name mover\", \"negative name mover\"]):\n", " make_table(\n", " title=f\"Copying Scores ({name} heads)\",\n", " colnames=[\"Head\", \"Score\"],\n", " cols=[\n", " list(map(str, heads[name])) + [\"[dark_orange bold]Average\"],\n", " [f\"{copying_results[i, layer-1, head]:.2%}\" for (layer, head) in heads[name]] + [f\"[dark_orange bold]{copying_results[i].mean():.2%}\"]\n", " ]\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "KhdfNDbNHu9N" }, "source": [ "
Solution\n", "\n", "```python\n", "def get_copying_scores(model: HookedTransformer, k: int = 5, names: list = NAMES) -> Float[Tensor, \"2 layer-1 head\"]:\n", " \"\"\"\n", " Gets copying scores (both positive and negative) as described in page 6 of the IOI paper, for every (layer, head)\n", " pair in the model.\n", "\n", " Returns these in a 3D tensor (the first dimension is for positive vs negative).\n", "\n", " Omits the 0th layer, because this is before MLP0 (which we're claiming acts as an extended embedding).\n", " \"\"\"\n", " results = t.zeros((2, model.cfg.n_layers, model.cfg.n_heads), device=device)\n", "\n", " # Define components from our model (for typechecking, and cleaner code)\n", " embed: Embed = model.embed\n", " mlp0: MLP = model.blocks[0].mlp\n", " ln0: LayerNorm = model.blocks[0].ln2\n", " unembed: Unembed = model.unembed\n", " ln_final: LayerNorm = model.ln_final\n", "\n", " # Get embeddings for the names in our list\n", " name_tokens: Int[Tensor, \"batch 1\"] = model.to_tokens(names, prepend_bos=False)\n", " name_embeddings: Int[Tensor, \"batch 1 d_model\"] = embed(name_tokens)\n", "\n", " # Get residual stream after applying MLP\n", " resid_after_mlp1 = name_embeddings + mlp0(ln0(name_embeddings))\n", "\n", " # Loop over all (layer, head) pairs\n", " for layer in range(1, model.cfg.n_layers):\n", " for head in range(model.cfg.n_heads):\n", " # Get W_OV matrix\n", " W_OV = model.W_V[layer, head] @ model.W_O[layer, head]\n", "\n", " # Get residual stream after applying W_OV or -W_OV respectively\n", " # (note, because of bias b_U, it matters that we do sign flip here, not later)\n", " resid_after_OV_pos = resid_after_mlp1 @ W_OV\n", " resid_after_OV_neg = resid_after_mlp1 @ -W_OV\n", "\n", " # Get logits from value of residual stream\n", " logits_pos = unembed(ln_final(resid_after_OV_pos)).squeeze() # [batch d_vocab]\n", " logits_neg = unembed(ln_final(resid_after_OV_neg)).squeeze() # [batch d_vocab]\n", "\n", " # Check how many are in top k\n", " topk_logits: Int[Tensor, \"batch k\"] = t.topk(logits_pos, dim=-1, k=k).indices\n", " in_topk = (topk_logits == name_tokens).any(-1)\n", " # Check how many are in bottom k\n", " bottomk_logits: Int[Tensor, \"batch k\"] = t.topk(logits_neg, dim=-1, k=k).indices\n", " in_bottomk = (bottomk_logits == name_tokens).any(-1)\n", "\n", " # Fill in results\n", " results[:, layer - 1, head] = t.tensor([in_topk.float().mean(), in_bottomk.float().mean()])\n", "\n", " return results\n", "\n", "\n", "copying_results = get_copying_scores(model)\n", "\n", "imshow(\n", " copying_results,\n", " facet_col=0,\n", " facet_labels=[\"Positive copying scores\", \"Negative copying scores\"],\n", " title=\"Copying scores of attention heads' OV circuits\",\n", " width=900,\n", ")\n", "\n", "heads = {\"name mover\": [(9, 9), (10, 0), (9, 6)], \"negative name mover\": [(10, 7), (11, 10)]}\n", "\n", "for i, name in enumerate([\"name mover\", \"negative name mover\"]):\n", " make_table(\n", " title=f\"Copying Scores ({name} heads)\",\n", " colnames=[\"Head\", \"Score\"],\n", " cols=[\n", " list(map(str, heads[name])) + [\"[dark_orange bold]Average\"],\n", " [f\"{copying_results[i, layer - 1, head]:.2%}\" for (layer, head) in heads[name]]\n", " + [f\"[dark_orange bold]{copying_results[i].mean():.2%}\"],\n", " ],\n", " )\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "yuIbfcPGHu9N" }, "source": [ "## Validation of early heads" ] }, { "cell_type": "markdown", "metadata": { "id": "iTH1gkxhHu9N" }, "source": [ "There are three different kinds of heads which appear early in the circuit, which can be validated by looking at their attention patterns on simple random sequences of tokens. Can you figure out which three types these are, and how to validate them in this way?\n", "\n", "
\n", "Answer\n", "\n", "Previous token heads, induction heads, and duplicate token heads.\n", "\n", "We can validate them all at the same time, using sequences of `n` random tokens followed by those same `n` random tokens repeated. This works as follows:\n", "\n", "* Prev token heads, by measuring the attention patterns with an offset of one (i.e. one below the diagonal).\n", "* Induction heads, by measuring the attention patterns with an offset of `n-1` (i.e. the second instance of a token paying attention to the token after its first instance).\n", "* Duplicate token heads, by measuring the attention patterns with an offset of `n` (i.e. a token paying attention to its previous instance).\n", "\n", "In all three cases, if heads score close to 1 on these metrics, it's strong evidence that they are working as this type of head.\n", "\n", "Note, it's a leaky abstraction to say things like \"head X is an induction head\", since we're only observing it on a certain distribution. For instance, it's not clear what the role of induction heads and duplicate token heads is when there are no duplicates (they could in theory do something completely different).\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "anZPQgiMHu9N" }, "source": [ "### Exercise - perform head validation\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴⚪⚪\n", "> Importance: 🔵🔵🔵⚪⚪\n", ">\n", "> You should spend up to 20-30 minutes on this exercise.\n", ">\n", "> Understanding how to identify certain types of heads by their characteristic attention patterns is important.\n", "> ```\n", "\n", "Once you've read the answer in the dropdown above, you should perform this validation. The result should be a replication of Figure 18 in the paper.\n", "\n", "We've provided a template for this function. Note use of `typing.Literal`, which is how we indicate that the argument should be one of the following options.\n", "\n", "We've also provided a helper function `generate_repeated_tokens` (which is similar to the one you used from exercise set 1.2, except that it has no start token, to match the paper), and a helper function `plot_early_head_validation_results` which calls the `get_attn_scores` function and plots the results (in a way which should look like Figure 18). So it's just the `get_attn_scores` function that you need to fill in." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "T8CYRcyLHu9N" }, "outputs": [], "source": [ "def generate_repeated_tokens(\n", " model: HookedTransformer, seq_len: int, batch: int = 1\n", ") -> Float[Tensor, \"batch 2*seq_len\"]:\n", " \"\"\"\n", " Generates a sequence of repeated random tokens (no start token).\n", " \"\"\"\n", " rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)\n", " rep_tokens = t.cat([rep_tokens_half, rep_tokens_half], dim=-1).to(device)\n", " return rep_tokens\n", "\n", "\n", "def get_attn_scores(\n", " model: HookedTransformer, seq_len: int, batch: int, head_type: Literal[\"duplicate\", \"prev\", \"induction\"]\n", ") -> Float[Tensor, \"n_layers n_heads\"]:\n", " \"\"\"\n", " Returns attention scores for sequence of duplicated tokens, for every head.\n", " \"\"\"\n", " raise NotImplementedError()\n", "\n", "\n", "def plot_early_head_validation_results(seq_len: int = 50, batch: int = 50):\n", " \"\"\"\n", " Produces a plot that looks like Figure 18 in the paper.\n", " \"\"\"\n", " head_types = [\"duplicate\", \"prev\", \"induction\"]\n", "\n", " results = t.stack([get_attn_scores(model, seq_len, batch, head_type=head_type) for head_type in head_types])\n", "\n", " imshow(\n", " results,\n", " facet_col=0,\n", " facet_labels=[\n", " f\"{head_type.capitalize()} token attention prob.
on sequences of random tokens\"\n", " for head_type in head_types\n", " ],\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " width=1300,\n", " )" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "xoRIhZFfHu9N" }, "outputs": [], "source": [ "model.reset_hooks()\n", "plot_early_head_validation_results()" ] }, { "cell_type": "markdown", "metadata": { "id": "hVoT-kL9Hu9N" }, "source": [ "
Solution\n", "\n", "```python\n", "def generate_repeated_tokens(\n", " model: HookedTransformer, seq_len: int, batch: int = 1\n", ") -> Float[Tensor, \"batch 2*seq_len\"]:\n", " \"\"\"\n", " Generates a sequence of repeated random tokens (no start token).\n", " \"\"\"\n", " rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)\n", " rep_tokens = t.cat([rep_tokens_half, rep_tokens_half], dim=-1).to(device)\n", " return rep_tokens\n", "\n", "\n", "def get_attn_scores(\n", " model: HookedTransformer, seq_len: int, batch: int, head_type: Literal[\"duplicate\", \"prev\", \"induction\"]\n", ") -> Float[Tensor, \"n_layers n_heads\"]:\n", " \"\"\"\n", " Returns attention scores for sequence of duplicated tokens, for every head.\n", " \"\"\"\n", " rep_tokens = generate_repeated_tokens(model, seq_len, batch)\n", "\n", " _, cache = model.run_with_cache(rep_tokens, return_type=None, names_filter=lambda name: name.endswith(\"pattern\"))\n", "\n", " # Get the right indices for the attention scores\n", "\n", " if head_type == \"duplicate\":\n", " src_indices = range(seq_len)\n", " dest_indices = range(seq_len, 2 * seq_len)\n", " elif head_type == \"prev\":\n", " src_indices = range(seq_len)\n", " dest_indices = range(1, seq_len + 1)\n", " elif head_type == \"induction\":\n", " dest_indices = range(seq_len, 2 * seq_len)\n", " src_indices = range(1, seq_len + 1)\n", "\n", " results = t.zeros(model.cfg.n_layers, model.cfg.n_heads, device=device, dtype=t.float32)\n", " for layer in range(model.cfg.n_layers):\n", " for head in range(model.cfg.n_heads):\n", " attn_scores = cache[\"pattern\", layer] # [batch seqQ seqK]\n", " avg_attn_on_duplicates = attn_scores[:, head, dest_indices, src_indices].mean().item()\n", " results[layer, head] = avg_attn_on_duplicates\n", "\n", " return results\n", "\n", "\n", "def plot_early_head_validation_results(seq_len: int = 50, batch: int = 50):\n", " \"\"\"\n", " Produces a plot that looks like Figure 18 in the paper.\n", " \"\"\"\n", " head_types = [\"duplicate\", \"prev\", \"induction\"]\n", "\n", " results = t.stack([get_attn_scores(model, seq_len, batch, head_type=head_type) for head_type in head_types])\n", "\n", " imshow(\n", " results,\n", " facet_col=0,\n", " facet_labels=[\n", " f\"{head_type.capitalize()} token attention prob.
on sequences of random tokens\"\n", " for head_type in head_types\n", " ],\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " width=1300,\n", " )\n", "\n", "\n", "model.reset_hooks()\n", "plot_early_head_validation_results()\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "SgibQ4XqHu9N" }, "source": [ "Note - these figures suggest that it would be a useful bit of infrastructure to have a \"wiki\" for the heads of a model, giving their scores according to some metrics re head functions, like the ones we've seen here. HookedTransformer makes this pretty easy to make, as just changing the name input to `HookedTransformer.from_pretrained` gives a different model but in the same architecture, so the same code should work. If you want to make this, I'd love to see it!\n", "\n", "As a proof of concept, [I made a mosaic of all induction heads across the 40 models then in HookedTransformer](https://www.neelnanda.io/mosaic)." ] }, { "cell_type": "markdown", "metadata": { "id": "uVSGojnWHu9N" }, "source": [ "## Minimal Circuit" ] }, { "cell_type": "markdown", "metadata": { "id": "W-tTxNpXHu9N" }, "source": [ "### Background: faithfulness, completeness, and minimality" ] }, { "cell_type": "markdown", "metadata": { "id": "o5QSksyxHu9N" }, "source": [ "The authors developed three criteria for validating their circuit explanations: faithfulness, completeness and minimality. They are defined as follows:\n", "\n", "* **Faithful** = the circuit can perform as well as the whole model\n", "* **Complete** = the circuit contains all nodes used to perform the task\n", "* **Minimal** = the circuit doesn't contain nodes irrelevant to the task\n", "\n", "If all three criteria are met, then the circuit is considered a reliable explanation for model behaviour.\n", "\n", "Exercise - can you understand why each of these criteria is important? For each pair of criteria, is it possible for a circuit to meet them both but fail the third (and if yes, can you give examples?).\n", "\n", "
\n", "Answer\n", "\n", "The naive circuit (containing the entire model) is trivially faithful and complete, but obviously not minimal. In general, the problem with non-minimal circuits is that they may not be mechanistically understandable, which defeats the purpose of this kind of circuit analysis.\n", "\n", "Completeness obviously implies faithfulness, because if a node isn't involved in the task, then it can't improve the model's performance on that task.\n", "\n", "You might initially think faithfulness implies completeness, but this actually isn't true. Backup name mover heads illustrate this point. They are used in the task, and without understanding the role they play you'll have an incorrect model of reality (e.g. you'll think ablating the name mover heads would destroy performance, which turns out not to be true). So if you define a circuit that doesn't contain backup name mover heads then it will be faithful (the backup name mover heads won't be used) but not complete.\n", "\n", "Summary:\n", "\n", "* **Faithful & complete, not minimal** = possible (example: naive circuit)\n", "* **Faithful & minimal, not complete** = possible (example: circuit missing backup name mover heads)\n", "* **Complete & minimal, not faithful** = impossible (since completeness implies faithfulness)\n", "\n", "In the paper, the authors formalise these concepts. Faithfulness is equivalent to $|F(C) - F(M)|$ being small (where $C$ is our circuit, $M$ is the model, and $F$ is our performance metric function), and completeness is equivalent to $|F(C\\backslash K) - F(M\\backslash K)|$ being small **for any subset $K \\subset C$** (including when $K$ is the empty set, showing that completeness implies faithfulness). Can you see how our circuit missing backup name mover heads violates this condition?\n", "\n", "
\n", "Answer\n", "\n", "It violates the condition when $K$ is the set of name mover heads. $C \\backslash K$ performs worse than $M \\backslash K$, because the latter contains backup name mover heads while the former has lost its name mover heads ***and*** backup name mover heads.\n", "
\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "J2IqY7FQHu9N" }, "source": [ "Now that we've analysed most components of the circuit and we have a rough idea of how they work, the next step is to ablate everything except those core components and verify that the model still performs well.\n", "\n", "This ablation is pretty massive - we're ablating everything except for the output of each of our key attention heads (e.g. duplicate token heads or S-inhibition heads) at a single sequence position (e.g. for the DTHs this is the `S2` token, and for SIHs this is the `end` token). Given that our core circuit has 26 heads in total, and our sequences have length around 20 on average, this means we're ablating all but $(26 / 144) / 20 \\approx 1\\%$ of our attention heads' output (and the number of possible paths through the model is reduced by ***much*** more than this)." ] }, { "cell_type": "markdown", "metadata": { "id": "lTbY5OUJHu9O" }, "source": [ "How do we ablate? We could use zero-ablation, but this actually has some non-obvious problems. To explain why intuitively, heads might be \"expecting\" non-zero input, and setting the input to zero is essentially an arbitrary choice which takes it off distribution. You can think of this as adding a bias term to this head, which might mess up subsequent computation and lead to noisy results. We could also use mean-ablation (i.e. set a head's output to its average output over `ioi_dataset`), but the problem here is that taking the mean over this dataset might contain relevant information for solving the IOI task. For example the `is_duplicated` flag which gets written to `S2` will be present for all sequences, so averaging won't remove this information.\n", "\n", "Can you think of a way to solve this problem? After you've considered this, you can use the dropdown below to see how the authors handled this.\n", "\n", "
\n", "Answer\n", "\n", "We ablate with the mean of the ABC dataset rather than the IOI dataset. This removes the problem of averages still containing relevant information from solving the IOI task.\n", "\n", "
\n", "\n", "One other complication - the sentences have different templates, and the positions of tokens like `S` and `IO` are not consistent across these templates (we avoided this problem in previous exercises by choosing a very small set of sentences, where all the important tokens had the same indices). An example of two templates with different positions:\n", "\n", "```\n", "\"Then, [B] and [A] had a long argument and after that [B] said to [A]\"\n", "\"After the lunch [B] and [A] went to the [PLACE], and [B] gave a [OBJECT] to [A]\"\n", "```\n", "\n", "Can you guess what the authors did to solve this problem?\n", "\n", "
\n", "Answer\n", "\n", "They took the mean over each template rather than over the whole dataset, and used these values to ablate with.\n", "\n", "In other words, when they performed ablation by patching in the output of a head (which has shape `(batch, seq, d_model)`), the value patched into the `(i, j, k)`-th element of this tensor would be the average value of the `k`-th element of the vector at sequence position `j`, for all sentences with the same template as the `i`-th sentence in the batch.\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "u2KyE6OmHu9O" }, "source": [ "### Exercise - constructing the minimal circuit\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴🔴\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> This exercise is expected to take a long time; at least an hour. It is probably the most challenging exercise in this notebook.\n", "> ```" ] }, { "cell_type": "markdown", "metadata": { "id": "5UGiIFwwHu9O" }, "source": [ "You now have enough information to perform ablation on your model, to get the minimal circuit. Below, you can try to implement this yourself.\n", "\n", "This exercise is very technically challenging, so you're welcome to skip it if it doesn't seem interesting to you. However, I recommend you have a read of the solution, to understand the rough contours of how this ablation works.\n", "\n", "If you want to attempt this task, then you can start with the code below. We define two dictionaries, one mapping head types to the heads in the model which are of that type, and the other mapping head types to the sequence positions which we *won't* be ablating for those types of head." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "C-EGwD2tHu9O" }, "outputs": [], "source": [ "CIRCUIT = {\n", " \"name mover\": [(9, 9), (10, 0), (9, 6)],\n", " \"backup name mover\": [(10, 10), (10, 6), (10, 2), (10, 1), (11, 2), (9, 7), (9, 0), (11, 9)],\n", " \"negative name mover\": [(10, 7), (11, 10)],\n", " \"s2 inhibition\": [(7, 3), (7, 9), (8, 6), (8, 10)],\n", " \"induction\": [(5, 5), (5, 8), (5, 9), (6, 9)],\n", " \"duplicate token\": [(0, 1), (0, 10), (3, 0)],\n", " \"previous token\": [(2, 2), (4, 11)],\n", "}\n", "\n", "SEQ_POS_TO_KEEP = {\n", " \"name mover\": \"end\",\n", " \"backup name mover\": \"end\",\n", " \"negative name mover\": \"end\",\n", " \"s2 inhibition\": \"end\",\n", " \"induction\": \"S2\",\n", " \"duplicate token\": \"S2\",\n", " \"previous token\": \"S1+1\",\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "CxkHBPPQHu9O" }, "source": [ "To be clear, the things that we'll be mean-ablating are:\n", "\n", "* Every head not in the `CIRCUIT` dict\n", "* Every sequence position for the heads in `CIRCUIT` dict, except for the sequence positions given by the `SEQ_POS_TO_KEEP` dict\n", "\n", "And we'll be mean-ablating by replacing a head's output with the mean output for `abc_dataset`, over all sentences with the same template as the sentence in the batch. You can access the templates for a dataset using the `dataset.groups` attribute, which returns a list of tensors (each one containing the indices of sequences in the batch sharing the same template)." ] }, { "cell_type": "markdown", "metadata": { "id": "J4IyS7BVHu9O" }, "source": [ "Now, you can try and complete the following function, which should add a ***permanent hook*** to perform this ablation whenever the model is run on the `ioi_dataset` (note that this hook will only make sense if the model is run on this dataset, so we should reset hooks if we run it on a different dataset).\n", "\n", "Permanent hooks are a useful feature of transformerlens. They behave just like regular hooks, except they aren't removed when you run the model (e.g. using `model.run_with_cache` or `model.run_with_hooks`). 
The only way to remove them is with:\n", "\n", "```python\n", "model.reset_hooks(including_permanent=True)\n", "```\n", "\n", "You can add permanent hooks as follows:\n", "\n", "```python\n", "model.add_hook(hook_name, hook_fn, is_permanent=True)\n", "```\n", "\n", "where `hook_name` can be a string or a filter function mapping strings to booleans." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jkvQLTsnHu9O" }, "outputs": [], "source": [ "def add_mean_ablation_hook(\n", " model: HookedTransformer,\n", " means_dataset: IOIDataset,\n", " circuit: dict[str, list[tuple[int, int]]] = CIRCUIT,\n", " seq_pos_to_keep: dict[str, str] = SEQ_POS_TO_KEEP,\n", " is_permanent: bool = True,\n", ") -> HookedTransformer:\n", " \"\"\"\n", " Adds a permanent hook to the model, which ablates according to the circuit and seq_pos_to_keep dictionaries.\n", "\n", " In other words, when the model is run on ioi_dataset, every head's output will be replaced with the mean over\n", " means_dataset for sequences with the same template, except for a subset of heads and sequence positions as specified\n", " by the circuit and seq_pos_to_keep dicts.\n", " \"\"\"\n", " raise NotImplementedError()" ] }, { "cell_type": "markdown", "metadata": { "id": "VX9GBXs9Hu9P" }, "source": [ "
\n", "Hint (docstrings of some functions which will be useful for your main function)\n", "\n", "```python\n", "def compute_means_by_template(\n", " means_dataset: IOIDataset,\n", " model: HookedTransformer\n", ") -> Float[Tensor, \"layer batch seq head_idx d_head\"]:\n", " '''\n", " Returns the mean of each head's output over the means dataset. This mean is\n", " computed separately for each group of prompts with the same template (these\n", " are given by means_dataset.groups).\n", " '''\n", " pass\n", "\n", "\n", "def get_heads_and_posns_to_keep(\n", " means_dataset: IOIDataset,\n", " model: HookedTransformer,\n", " circuit: dict[str, list[tuple[int, int]]],\n", " seq_pos_to_keep: dict[str, str],\n", ") -> dict[int, Bool[Tensor, \"batch seq head\"]]:\n", " '''\n", " Returns a dictionary mapping layers to a boolean mask giving the indices of the\n", " z output which *shouldn't* be mean-ablated.\n", "\n", " The output of this function will be used for the hook function that does ablation.\n", " '''\n", " pass\n", "\n", "\n", "def hook_fn_mask_z(\n", " z: Float[Tensor, \"batch seq head d_head\"],\n", " hook: HookPoint,\n", " heads_and_posns_to_keep: dict[int, Bool[Tensor, \"batch seq head\"]],\n", " means: Float[Tensor, \"layer batch seq head d_head\"],\n", ") -> Float[Tensor, \"batch seq head d_head\"]:\n", " '''\n", " Hook function which masks the z output of a transformer head.\n", "\n", " heads_and_posns_to_keep\n", " dict created with the get_heads_and_posns_to_keep function. This tells\n", " us where to mask.\n", "\n", " means\n", " Tensor of mean z values of the means_dataset over each group of prompts\n", " with the same template. This tells us what values to mask with.\n", " '''\n", " pass\n", "```\n", "\n", "Once you fill in these three functions, completing the main function is simple. It should:\n", "\n", "* Use `compute_means_by_template` to get the means\n", "* Use `get_heads_and_posns_to_keep` to get the boolean mask\n", "* Apply `functools.partial` to `hook_fn_mask_z`, using the outputs of the 2 previous functions, to get your hook function which performs the mean ablation\n", "* Add this hook function to your model, as a permanent hook\n", "\n", "
\n", "\n", "\n", "
Solution\n", "\n", "```python\n", "def get_heads_and_posns_to_keep(\n", " means_dataset: IOIDataset,\n", " model: HookedTransformer,\n", " circuit: dict[str, list[tuple[int, int]]],\n", " seq_pos_to_keep: dict[str, str],\n", ") -> dict[int, Bool[Tensor, \"batch seq head\"]]:\n", " \"\"\"\n", " Returns a dictionary mapping layers to a boolean mask giving the indices of the z output which *shouldn't* be\n", " mean-ablated.\n", "\n", " The output of this function will be used for the hook function that does ablation.\n", " \"\"\"\n", " heads_and_posns_to_keep = {}\n", " batch, seq, n_heads = len(means_dataset), means_dataset.max_len, model.cfg.n_heads\n", "\n", " for layer in range(model.cfg.n_layers):\n", " mask = t.zeros(size=(batch, seq, n_heads))\n", "\n", " for head_type, head_list in circuit.items():\n", " seq_pos = seq_pos_to_keep[head_type]\n", " indices = means_dataset.word_idx[seq_pos]\n", " for layer_idx, head_idx in head_list:\n", " if layer_idx == layer:\n", " mask[:, indices, head_idx] = 1\n", "\n", " heads_and_posns_to_keep[layer] = mask.bool()\n", "\n", " return heads_and_posns_to_keep\n", "\n", "\n", "def hook_fn_mask_z(\n", " z: Float[Tensor, \"batch seq head d_head\"],\n", " hook: HookPoint,\n", " heads_and_posns_to_keep: dict[int, Bool[Tensor, \"batch seq head\"]],\n", " means: Float[Tensor, \"layer batch seq head d_head\"],\n", ") -> Float[Tensor, \"batch seq head d_head\"]:\n", " \"\"\"\n", " Hook function which masks the z output of a transformer head.\n", "\n", " heads_and_posns_to_keep\n", " dict created with the get_heads_and_posns_to_keep function. This tells us where to mask.\n", "\n", " means\n", " Tensor of mean z values of the means_dataset over each group of prompts with the same template. This tells us\n", " what values to mask with.\n", " \"\"\"\n", " # Get the mask for this layer, and add d_head=1 dimension so it broadcasts correctly\n", " mask_for_this_layer = heads_and_posns_to_keep[hook.layer()].unsqueeze(-1).to(z.device)\n", "\n", " # Set z values to the mean\n", " z = t.where(mask_for_this_layer, z, means[hook.layer()])\n", "\n", " return z\n", "\n", "\n", "def compute_means_by_template(\n", " means_dataset: IOIDataset, model: HookedTransformer\n", ") -> Float[Tensor, \"layer batch seq head_idx d_head\"]:\n", " \"\"\"\n", " Returns the mean of each head's output over the means dataset. 
This mean is computed separately for each group of\n", " prompts with the same template (these are given by means_dataset.groups).\n", " \"\"\"\n", " # Cache the outputs of every head\n", " _, means_cache = model.run_with_cache(\n", " means_dataset.toks.long(),\n", " return_type=None,\n", " names_filter=lambda name: name.endswith(\"z\"),\n", " )\n", " # Create tensor to store means\n", " n_layers, n_heads, d_head = model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_head\n", " batch, seq_len = len(means_dataset), means_dataset.max_len\n", " means = t.zeros(size=(n_layers, batch, seq_len, n_heads, d_head), device=model.cfg.device)\n", "\n", " # Get set of different templates for this data\n", " for layer in range(model.cfg.n_layers):\n", " z_for_this_layer = means_cache[utils.get_act_name(\"z\", layer)] # [batch seq head d_head]\n", " for template_group in means_dataset.groups:\n", " z_for_this_template = z_for_this_layer[template_group]\n", " z_means_for_this_template = einops.reduce(\n", " z_for_this_template, \"batch seq head d_head -> seq head d_head\", \"mean\"\n", " )\n", " means[layer, template_group] = z_means_for_this_template\n", "\n", " return means\n", "\n", "\n", "def add_mean_ablation_hook(\n", " model: HookedTransformer,\n", " means_dataset: IOIDataset,\n", " circuit: dict[str, list[tuple[int, int]]] = CIRCUIT,\n", " seq_pos_to_keep: dict[str, str] = SEQ_POS_TO_KEEP,\n", " is_permanent: bool = True,\n", ") -> HookedTransformer:\n", " \"\"\"\n", " Adds a permanent hook to the model, which ablates according to the circuit and seq_pos_to_keep dictionaries.\n", "\n", " In other words, when the model is run on ioi_dataset, every head's output will be replaced with the mean over\n", " means_dataset for sequences with the same template, except for a subset of heads and sequence positions as specified\n", " by the circuit and seq_pos_to_keep dicts.\n", " \"\"\"\n", " model.reset_hooks(including_permanent=True)\n", "\n", " # Compute the mean of each head's output on the ABC dataset, grouped by template\n", " means = compute_means_by_template(means_dataset, model)\n", "\n", " # Convert this into a boolean map\n", " heads_and_posns_to_keep = get_heads_and_posns_to_keep(means_dataset, model, circuit, seq_pos_to_keep)\n", "\n", " # Get a hook function which will patch in the mean z values for each head, at\n", " # all positions which aren't important for the circuit\n", " hook_fn = partial(hook_fn_mask_z, heads_and_posns_to_keep=heads_and_posns_to_keep, means=means)\n", "\n", " # Apply hook\n", " model.add_hook(lambda name: name.endswith(\"z\"), hook_fn, is_permanent=is_permanent)\n", "\n", " return model\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "lxz606gYHu9P" }, "source": [ "To test whether your function works, you can use the function provided to you, and see if the logit difference from your implementation of the circuit matches this one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yGfnOD3iHu9P" }, "outputs": [], "source": [ "import part41_indirect_object_identification.ioi_circuit_extraction as ioi_circuit_extraction\n", "\n", "model = ioi_circuit_extraction.add_mean_ablation_hook(\n", " model,\n", " means_dataset=abc_dataset,\n", " circuit=CIRCUIT,\n", " seq_pos_to_keep=SEQ_POS_TO_KEEP,\n", ")\n", "ioi_logits_minimal = model(ioi_dataset.toks)\n", "\n", "print(f\"\"\"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original):.4f}\n", "Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal):.4f}\"\"\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_VN4JSPsHu9P" }, "outputs": [], "source": [ "model = add_mean_ablation_hook(\n", " model,\n", " means_dataset=abc_dataset,\n", " circuit=CIRCUIT,\n", " seq_pos_to_keep=SEQ_POS_TO_KEEP,\n", ")\n", "ioi_logits_minimal = model(ioi_dataset.toks)\n", "\n", "print(f\"\"\"Average logit difference (IOI dataset, using entire model): {logits_to_ave_logit_diff_2(ioi_logits_original):.4f}\n", "Average logit difference (IOI dataset, only using circuit): {logits_to_ave_logit_diff_2(ioi_logits_minimal):.4f}\"\"\")" ] }, { "cell_type": "markdown", "metadata": { "id": "W2YOz1NhHu9P" }, "source": [ "You should find that the logit difference only drops by a small amount, and is still high enough to represent a high likelihood ratio favouring the IO token over S." ] }, { "cell_type": "markdown", "metadata": { "id": "0nI-Mk-fHu9P" }, "source": [ "### Exercise - calculate minimality scores\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴🔴🔴🔴\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> This exercise is expected to take a long time; at least an hour. It is probably the second most challenging exercise in this notebook.\n", "> ```" ] }, { "cell_type": "markdown", "metadata": { "id": "JRFY4gLjHu9P" }, "source": [ "We'll conclude this section by replicating figure 7 of the paper, which shows the minimality scores for the model.\n", "\n", "Again, this exercise is very challenging and is designed to be done with minimal guidance. You will need to read the relevant sections of the paper which explain this plot: section 4 (experimental validation), from the start up to the end of section 4.2. Note that you won't need to perform the sampling algorithm described on page 10, because we're giving you the set $K$ for each component, in the form of the dictionary below (this is based on the information given in figure 20 of the paper, the \"minimality sets\" table)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "qLBCWfz3Hu9P" }, "outputs": [], "source": [ "K_FOR_EACH_COMPONENT = {\n", " (9, 9): set(),\n", " (10, 0): {(9, 9)},\n", " (9, 6): {(9, 9), (10, 0)},\n", " (10, 7): {(11, 10)},\n", " (11, 10): {(10, 7)},\n", " (8, 10): {(7, 9), (8, 6), (7, 3)},\n", " (7, 9): {(8, 10), (8, 6), (7, 3)},\n", " (8, 6): {(7, 9), (8, 10), (7, 3)},\n", " (7, 3): {(7, 9), (8, 10), (8, 6)},\n", " (5, 5): {(5, 9), (6, 9), (5, 8)},\n", " (5, 9): {(11, 10), (10, 7)},\n", " (6, 9): {(5, 9), (5, 5), (5, 8)},\n", " (5, 8): {(11, 10), (10, 7)},\n", " (0, 1): {(0, 10), (3, 0)},\n", " (0, 10): {(0, 1), (3, 0)},\n", " (3, 0): {(0, 1), (0, 10)},\n", " (4, 11): {(2, 2)},\n", " (2, 2): {(4, 11)},\n", " (11, 2): {(9, 9), (10, 0), (9, 6)},\n", " (10, 6): {(9, 9), (10, 0), (9, 6), (11, 2)},\n", " (10, 10): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6)},\n", " (10, 2): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10)},\n", " (9, 7): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2)},\n", " (10, 1): {(9, 9), (10, 0), (9, 6), (11, 2), (10, 6), (10, 10), (10, 2), (9, 7)},\n", " (11, 9): {(9, 9), (10, 0), (9, 6), (9, 0)},\n", " (9, 0): {(9, 9), (10, 0), (9, 6), (11, 9)},\n", "}" ] }, { "cell_type": "markdown", "metadata": { "id": "eeRTSBzJHu9P" }, "source": [ "Also, given a dictionary `minimality_scores` (which maps heads to their scores), the following code will produce a plot that looks like the one from the paper:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B5F5JywAHu9P" }, "outputs": [], "source": [ "def plot_minimal_set_results(minimality_scores: dict[tuple[int, int], float]):\n", " \"\"\"\n", " Plots the minimality results, in a way resembling figure 7 in the paper.\n", "\n", " minimality_scores:\n", " dict with elements like (9, 9): minimality score for head 9.9 (as described\n", " in section 4.2 of the paper)\n", " \"\"\"\n", "\n", " CIRCUIT_reversed = {head: k for k, v in CIRCUIT.items() for head in v}\n", " colors = [CIRCUIT_reversed[head].capitalize() + \" head\" for head in minimality_scores.keys()]\n", " color_sequence = [px.colors.qualitative.Dark2[i] for i in [0, 1, 2, 5, 3, 6]] + [\"#BAEA84\"]\n", "\n", " bar(\n", " list(minimality_scores.values()),\n", " x=list(map(str, minimality_scores.keys())),\n", " labels={\"x\": \"Attention head\", \"y\": \"Change in logit diff\", \"color\": \"Head type\"},\n", " color=colors,\n", " template=\"ggplot2\",\n", " color_discrete_sequence=color_sequence,\n", " bargap=0.02,\n", " yaxis_tickformat=\".0%\",\n", " legend_title_text=\"\",\n", " title=\"Plot of minimality scores (as percentages of full model logit diff)\",\n", " width=800,\n", " hovermode=\"x unified\",\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "ocWm9EsIHu9P" }, "source": [ "Now, you should create the `minimality_scores` dictionary, and use the plot function given above to plot the results:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "U0iqijjfHu9P" }, "outputs": [], "source": [ "minimality_scores = {(9, 9): ...}\n", "plot_minimal_set_results(minimality_scores)\n", "# YOUR CODE HERE - create the `minimality_scores` dictionary, to be used in the plot function given above" ] }, { "cell_type": "markdown", "metadata": { "id": "G3dd5d8fHu9P" }, "source": [ "
\n", "Hint (docstrings of some functions which will be useful for your main function)\n", "\n", "```python\n", "def get_score(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset,\n", " abc_dataset: IOIDataset,\n", " K: set[tuple[int, int]],\n", " C: dict[str, list[tuple[int, int]]],\n", ") -> float:\n", " '''\n", " Returns the value F(C \\ K), where F is the logit diff, C is the\n", " core circuit, and K is the set of circuit components to remove.\n", " '''\n", " pass\n", "\n", "\n", "def get_minimality_score(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset,\n", " abc_dataset: IOIDataset,\n", " v: tuple[int, int],\n", " K: set[tuple[int, int]],\n", " C: dict[str, list[tuple[int, int]]] = CIRCUIT,\n", ") -> float:\n", " '''\n", " Returns the value | F(C \\ K_union_v) - F(C | K) |, where F is\n", " the logit diff, C is the core circuit, K is the set of circuit\n", " components to remove, and v is a head (not in K).\n", " '''\n", " pass\n", "\n", "\n", "def get_all_minimality_scores(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset = ioi_dataset,\n", " abc_dataset: IOIDataset = abc_dataset,\n", " k_for_each_component: dict = K_FOR_EACH_COMPONENT\n", ") -> dict[tuple[int, int], float]:\n", " '''\n", " Returns dict of minimality scores for every head in the model (as\n", " a fraction of F(M), the logit diff of the full model).\n", "\n", " Warning - this resets all hooks at the end (including permanent).\n", " '''\n", " pass\n", "\n", "\n", "minimality_scores = get_all_minimality_scores(model)\n", "\n", "plot_minimal_set_results(minimality_scores)\n", "```\n", "\n", "The output of the third function can be plotted using the plotting function given above.\n", "\n", "
\n", "\n", "\n", "
Solution\n", "\n", "```python\n", "def get_score(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset,\n", " abc_dataset: IOIDataset,\n", " K: set[tuple[int, int]],\n", " C: dict[str, list[tuple[int, int]]],\n", ") -> float:\n", " \"\"\"\n", " Returns the value F(C \\ K), where F is the logit diff, C is the core circuit, and K is the set of circuit components\n", " to remove.\n", " \"\"\"\n", " C_excl_K = {k: [head for head in v if head not in K] for k, v in C.items()}\n", " model = add_mean_ablation_hook(model, abc_dataset, C_excl_K, SEQ_POS_TO_KEEP)\n", " logits = model(ioi_dataset.toks)\n", " score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()\n", "\n", " return score\n", "\n", "\n", "def get_minimality_score(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset,\n", " abc_dataset: IOIDataset,\n", " v: tuple[int, int],\n", " K: set[tuple[int, int]],\n", " C: dict[str, list[tuple[int, int]]] = CIRCUIT,\n", ") -> float:\n", " \"\"\"\n", " Returns the value | F(C \\ K_union_v) - F(C | K) |, where F is the logit diff, C is the core circuit, K is the set of\n", " circuit components to remove, and v is a head (not in K).\n", " \"\"\"\n", " assert v not in K\n", " K_union_v = K | {v}\n", " C_excl_K_score = get_score(model, ioi_dataset, abc_dataset, K, C)\n", " C_excl_Kv_score = get_score(model, ioi_dataset, abc_dataset, K_union_v, C)\n", "\n", " return abs(C_excl_K_score - C_excl_Kv_score)\n", "\n", "\n", "def get_all_minimality_scores(\n", " model: HookedTransformer,\n", " ioi_dataset: IOIDataset = ioi_dataset,\n", " abc_dataset: IOIDataset = abc_dataset,\n", " k_for_each_component: dict = K_FOR_EACH_COMPONENT,\n", ") -> dict[tuple[int, int], float]:\n", " \"\"\"\n", " Returns dict of minimality scores for every head in the model (as\n", " a fraction of F(M), the logit diff of the full model).\n", "\n", " Warning - this resets all hooks at the end (including permanent).\n", " \"\"\"\n", " # Get full circuit score F(M), to divide minimality scores by\n", " model.reset_hooks(including_permanent=True)\n", " logits = model(ioi_dataset.toks)\n", " full_circuit_score = logits_to_ave_logit_diff_2(logits, ioi_dataset).item()\n", "\n", " # Get all minimality scores, using the `get_minimality_score` function\n", " minimality_scores = {}\n", " for v, K in tqdm(k_for_each_component.items()):\n", " score = get_minimality_score(model, ioi_dataset, abc_dataset, v, K)\n", " minimality_scores[v] = score / full_circuit_score\n", "\n", " model.reset_hooks(including_permanent=True)\n", "\n", " return minimality_scores\n", "\n", "minimality_scores = get_all_minimality_scores(model)\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "ew8dKQJoHu9P" }, "source": [ "Note - your results won't be exactly the same as the paper's, because of random error (e.g. the order of importanc of heads within each category might not be the same, especially heads with a small effect on the model like the backup name mover heads). But they should be reasonably similar in their important features." ] }, { "cell_type": "markdown", "metadata": { "id": "6WEak456Hu9P" }, "source": [ "# ☆ Bonus / exploring anomalies\n", "\n", "> ##### Learning Objectives\n", ">\n", "> * Explore other parts of the model (e.g. negative name mover heads, and induction heads)\n", "> * Understand the subtleties present in model circuits, and the fact that there are often more parts to a circuit than seem obvious after initial investigation\n", "> * Understand the importance of the three quantitative criteria used by the paper: **faithfulness**, **completeness** and **minimality**" ] }, { "cell_type": "markdown", "metadata": { "id": "G648T5hNHu9P" }, "source": [ "Here, you'll explore some weirder parts of the circuit which we haven't looked at in detail yet. Specifically, there are three parts to explore:\n", "\n", "* Early induction heads\n", "* Backup name mover heads\n", "* Positional vs token information being moved\n", "\n", "These three sections are all optional, and you can do as many of them as you want (in whichever order you prefer). There will also be some further directions of investigation at the end of this section, which have been suggested either by the authors or by Neel." ] }, { "cell_type": "markdown", "metadata": { "id": "OKkncEC3Hu9P" }, "source": [ "## Early induction heads" ] }, { "cell_type": "markdown", "metadata": { "id": "xctXY0gQHu9P" }, "source": [ "As we've discussed, a really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that?" ] }, { "cell_type": "markdown", "metadata": { "id": "lMh5w0jOHu9P" }, "source": [ "First off, let's just recap what induction heads are. An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [Anthropic wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).\n", "\n", "The diagram below shows a diagram for how they work to predict that the token `\"urs\"` follows `\" D\"`, the second time the word `\"Dursley\"` appears (note that we assume the model has not been trained on Harry Potter, so this is an example of in-context learning).\n", "\n", "" ] }, { "cell_type": "markdown", "metadata": { "id": "9wvGiic-Hu9P" }, "source": [ "Why is it surprising that induction heads up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. 
So it was useful to repurpose this existing machinery.\n", "\n", "This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon." ] }, { "cell_type": "markdown", "metadata": { "id": "dx4_ZN-1Hu9P" }, "source": [ "First, in the cell below, you should visualise the attention patterns of the induction heads (`5.5` and `6.9`) on sequences containg repeated tokens, and verify that they are attending to the token *after* the previous instance of that same token. You might want to repeat code you wrote in the \"Validation of duplicate token heads\" section." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rXRiAcAaHu9P" }, "outputs": [], "source": [ "model.reset_hooks(including_permanent=True)\n", "\n", "attn_heads = [(5, 5), (6, 9)]\n", "\n", "# Get repeating sequences (note we could also take mean over larger batch)\n", "batch = 1\n", "seq_len = 15\n", "rep_tokens = generate_repeated_tokens(model, seq_len, batch)\n", "\n", "# Run cache (we only need attention patterns for layers 5 and 6)\n", "_, cache = model.run_with_cache(\n", " rep_tokens,\n", " return_type=None,\n", " names_filter=lambda name: name.endswith(\"pattern\") and any(f\".{layer}.\" in name for layer, head in attn_heads),\n", ")\n", "\n", "# Display results\n", "attn = t.stack([cache[\"pattern\", layer][0, head] for (layer, head) in attn_heads])\n", "cv.attention.attention_patterns(\n", " tokens=model.to_str_tokens(rep_tokens[0]),\n", " attention=attn,\n", " attention_head_names=[f\"{layer}.{head}\" for (layer, head) in attn_heads],\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "uCAx445CHu9P" }, "source": [ "One implication of this is that it's useful to categories heads according to whether they occur in simpler circuits, so that as we look for more complex circuits we can easily look for them. This is Hooked to do here! An interesting fact about induction heads is that they work on a sequence of repeated random tokens - notable for being wildly off distribution from the natural language GPT-2 was trained on. Being able to predict a model's behaviour off distribution is a good mark of success for mechanistic interpretability! This is a good sanity check for whether a head is an induction head or not.\n", "\n", "We can characterise an induction head by just giving a sequence of random tokens repeated once, and measuring the average attention paid from the second copy of a token to the token after the first copy. At the same time, we can also measure the average attention paid from the second copy of a token to the first copy of the token, which is the attention that the induction head would pay if it were a duplicate token head, and the average attention paid to the previous token to find previous token heads.\n", "\n", "Note that this is a superficial study of whether something is an induction head - we totally ignore the question of whether it actually does boost the correct token or whether it composes with a single previous head and how. In particular, we sometimes get anti-induction heads which suppress the induction-y token (no clue why!), and this technique will find those too." 
] }, { "cell_type": "markdown", "metadata": { "id": "KYr-TbCfHu9P" }, "source": [ "### Exercise - validate prev token heads via patching\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴⚪⚪⚪\n", "> Importance: 🔵🔵⚪⚪⚪\n", ">\n", "> This just involves performing a specific kind of patching, with functions you've already written.\n", "> ```\n", "\n", "The paper mentions that heads `2.2` and `4.11` are previous token heads. Hopefully you already validated this in the previous section by plotting the previous token scores (in your replication of Figure 18). But this time, you'll perform a particular kind of path patching to prove that these heads are functioning as previous token heads, in the way implied by our circuit diagram.\n", "\n", "
\n", "Question - what kind of path patching should you perform?\n", "\n", "To show these heads are acting as prev token heads in an induction circuit, we want to perform key-patching (i.e. patch the path from the output of the prev token heads to the key inputs of the induction heads).\n", "\n", "We expect this to worsen performance, because it interrupts the duplicate token signal provided by the induction heads.\n", "
" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RrWD8by3Hu9P" }, "outputs": [], "source": [ "model.reset_hooks(including_permanent=True)\n", "\n", "# YOUR CODE HERE - create `induction_head_key_path_patching_results`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "c78Y84-PHu9P" }, "outputs": [], "source": [ "imshow(\n", " 100 * induction_head_key_path_patching_results,\n", " title=\"Direct effect on Induction Heads' keys\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\", \"color\": \"Logit diff.
variation\"},\n", " coloraxis=dict(colorbar_ticksuffix=\"%\"),\n", " width=600,\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "4XnkbwFRHu9Q" }, "source": [ "
Solution\n", "\n", "```python\n", "model.reset_hooks(including_permanent=True)\n", "\n", "induction_head_key_path_patching_results = get_path_patch_head_to_heads(\n", " receiver_heads=[(5, 5), (6, 9)], receiver_input=\"k\", model=model, patching_metric=ioi_metric_2\n", ")\n", "\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "LkAnnRLRHu9Q" }, "source": [ "You might notice that there are two other heads in the induction heads box in the paper's diagram, both in brackets (heads `5.8` and `5.9`). Can you try and figure out what these heads are doing, and why they're in brackets?" ] }, { "cell_type": "markdown", "metadata": { "id": "x44f2DvyHu9Q" }, "source": [ "
\n", "Hint\n", "\n", "Recall the path patching section, where you applied patching from attention heads' outputs to the value inputs to the S-inhibition heads (hopefully replicating figure 4b from the paper). Try patching the **keys** of the S-inhibition heads instead. Do heads `5.8` and `5.9` stand out here?\n", "
\n", "\n", "
\n", "Answer\n", "\n", "After making the plot suggested in the hint, you should find that patching from `5.8` and `5.9` each have around a 5% negative impact on the logit difference if you patch from them to the S-inhibition heads. We'd like to conclude that these are induction heads, and they're acting to increase the attention paid by `end` to `S2`. Unfortunately, they don't seem to be induction heads according to the results of figure 18 in the paper (which we replicated in the section \"validation of early heads\").\n", "\n", "It turns out that they're acting as induction heads on this distribution, but not in general. A simple way we can support this hypothesis is to look at attention patterns (we find that `S2` pays attention to `S1+1` in both head `5.8` and `5.9`).\n", "
\n", "\n", "
\n", "Aside - a lesson in nonlinearity\n", "\n", "In the dropdown above, I claimed that `5.8` and `5.9` were doing induction on this distribution. A more rigorous way to test this would be to path patch to the keys of these heads, and see if either of the previous token heads have a large effect. If they do, this is very strong evidence for induction. However, it turns out that neither of the prev token heads affect the model's IOI performance via their path through the fuzzy induction heads. Does this invalidate our hypothesis?\n", "\n", "In fact, no, because of nonlinearity. The two prev token heads `2.2` and `4.11` will be adding onto the **logits** used to calculate the attention probability from `S2` to `S1+1`, rather than adding onto the probabilities directly. SO it might be the case that one head on its own doesn't increase the attention probability very much, both both heads acting together significantly increase it. (Motivating example: suppose we only had 2 possible source tokens, without either head acting the logits are `[0.0, -8.0]`, and the action of each head is to add `4.0` to the second logit. If one head acts then the probability on the second token goes from basically zero to `0.018` (which will have a very small impact on the attention head's output), but if both heads are acting then the logits are equal and the probability is `0.5` (which obviously has a much larger effect)).\n", "\n", "This is in fact what I found here - after modifying the path patching function to allow more than one sender head, I found that patching from both `2.2` and `4.11` to `5.8` and `5.9` had a larger effect on the model's IOI performance than patching from either head alone (about a 1% drop in performance for both vs almost 0% for either individually). The same effect can be found when we patch from these heads to all four induction heads, although it's even more pronounced (a 3% and 15% drop respectively for the heads individually, vs 31% for both).\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "sbwVBXIeHu9Q" }, "source": [ "## Backup name mover heads" ] }, { "cell_type": "markdown", "metadata": { "id": "UmY5g26IHu9Q" }, "source": [ "Another fascinating anomaly is that of the **backup name mover heads**. A standard technique to apply when interpreting model internals is ablations, or knock-out. If we run the model but intervene to set a specific head to zero, what happens? If the model is robust to this intervention, then naively we can be confident that the head is not doing anything important, and conversely if the model is much worse at the task this suggests that head was important. There are several conceptual flaws with this approach, making the evidence only suggestive, e.g. that the average output of the head may be far from zero and so the knockout may send it far from expected activations, breaking internals on *any* task. But it's still an Hooked technique to apply to give some data.\n", "\n", "But a wild finding in the paper is that models have **built in redundancy**. If we knock out one of the name movers, then there are some backup name movers in later layers that *change their behaviour* and do (some of) the job of the original name mover head. This means that naive knock-out will significantly underestimate the importance of the name movers." ] }, { "cell_type": "markdown", "metadata": { "id": "ONq6NjPUHu9Q" }, "source": [ "Let's test this! Let's ablate the most important name mover (which is `9.9`, as the code below verifies for us) on just the `end` token, and compare performance." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u7m7aQOLHu9Q" }, "outputs": [], "source": [ "model.reset_hooks(including_permanent=True)\n", "\n", "ioi_logits, ioi_cache = model.run_with_cache(ioi_dataset.toks)\n", "original_average_logit_diff = logits_to_ave_logit_diff_2(ioi_logits)\n", "\n", "s_unembeddings = model.W_U.T[ioi_dataset.s_tokenIDs]\n", "io_unembeddings = model.W_U.T[ioi_dataset.io_tokenIDs]\n", "logit_diff_directions = io_unembeddings - s_unembeddings # [batch d_model]\n", "\n", "per_head_residual, labels = ioi_cache.stack_head_results(layer=-1, return_labels=True)\n", "per_head_residual = einops.rearrange(\n", " per_head_residual[:, t.arange(len(ioi_dataset)).to(device), ioi_dataset.word_idx[\"end\"].to(device)],\n", " \"(layer head) batch d_model -> layer head batch d_model\",\n", " layer=model.cfg.n_layers,\n", ")\n", "\n", "per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, ioi_cache, logit_diff_directions)\n", "\n", "top_layer, top_head = topk_of_Nd_tensor(per_head_logit_diffs, k=1)[0]\n", "print(f\"Top Name Mover to ablate: {top_layer}.{top_head}\")\n", "\n", "# Getting means we can use to ablate\n", "abc_means = ioi_circuit_extraction.compute_means_by_template(abc_dataset, model)[top_layer]\n", "\n", "\n", "# Define hook function and add to model\n", "def ablate_top_head_hook(z: Float[Tensor, \"batch pos head_index d_head\"], hook):\n", " \"\"\"\n", " Ablates hook by patching in results\n", " \"\"\"\n", " z[range(len(ioi_dataset)), ioi_dataset.word_idx[\"end\"], top_head] = abc_means[\n", " range(len(ioi_dataset)), ioi_dataset.word_idx[\"end\"], top_head\n", " ]\n", " return z\n", "\n", "\n", "model.add_hook(utils.get_act_name(\"z\", top_layer), ablate_top_head_hook)\n", "\n", "# Run the model, temporarily adds caching hooks and then removes *all* hooks after running, including the ablation hook.\n", "ablated_logits, ablated_cache = 
model.run_with_cache(ioi_dataset.toks)\n", "rprint(\n", " \"\\n\".join(\n", " [\n", " f\"{original_average_logit_diff:.4f} = Original logit diff\",\n", " f\"{per_head_logit_diffs[top_layer, top_head]:.4f} = Direct Logit Attribution of top name mover head\",\n", " f\"{original_average_logit_diff - per_head_logit_diffs[top_layer, top_head]:.4f} = Naive prediction of post ablation logit diff\",\n", " f\"{logits_to_ave_logit_diff_2(ablated_logits):.4f} = Logit diff after ablating L{top_layer}H{top_head}\",\n", " ]\n", " )\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "HiITENnnHu9Q" }, "source": [ "What's going on here? We calculate the logit diff for our full model, and how much of that is coming directly from head `9.9`. Given this, we come up with an estimate for what the logit diff will fall to when we ablate this head. In fact, performance is **much** better than this naive prediction.\n", "\n", "Why is this happening? As before, we can look at the direct logit attribution of each head to get a sense of what's going on." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "u2CflzapHu9Q" }, "outputs": [], "source": [ "per_head_ablated_residual, labels = ablated_cache.stack_head_results(layer=-1, return_labels=True)\n", "per_head_ablated_residual = einops.rearrange(\n", " per_head_ablated_residual[:, t.arange(len(ioi_dataset)).to(device), ioi_dataset.word_idx[\"end\"].to(device)],\n", " \"(layer head) batch d_model -> layer head batch d_model\",\n", " layer=model.cfg.n_layers,\n", ")\n", "per_head_ablated_logit_diffs = residual_stack_to_logit_diff(\n", " per_head_ablated_residual, ablated_cache, logit_diff_directions\n", ")\n", "per_head_ablated_logit_diffs = per_head_ablated_logit_diffs.reshape(model.cfg.n_layers, model.cfg.n_heads)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mTnbVfE7Hu9Q" }, "outputs": [], "source": [ "imshow(\n", " t.stack([per_head_logit_diffs, per_head_ablated_logit_diffs, per_head_ablated_logit_diffs - per_head_logit_diffs]),\n", " title=\"Direct logit contribution by head, pre / post ablation\",\n", " labels={\"x\": \"Head\", \"y\": \"Layer\"},\n", " facet_col=0,\n", " facet_labels=[\"No ablation\", \"9.9 is ablated\", \"Change in head contribution post-ablation\"],\n", " width=1200,\n", ")\n", "\n", "scatter(\n", " y=per_head_logit_diffs.flatten(),\n", " x=per_head_ablated_logit_diffs.flatten(),\n", " hover_name=labels,\n", " range_x=(-1, 1),\n", " range_y=(-2, 2),\n", " labels={\"x\": \"Ablated\", \"y\": \"Original\"},\n", " title=\"Original vs Post-Ablation Direct Logit Attribution of Heads\",\n", " width=600,\n", " add_line=\"y=x\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ZZe0VW19Hu9Q" }, "source": [ "The first plots show us that, after we ablate head `9.9`, while its direct contribution to the logit diff falls (obviously), a lot of contributions from other heads (particularly in layer 10) actually increase. The second plot shows this in a different way (the distinctive heads in the right hand heatmap are the same as the heads lying well below the y=x line in the scatter plot).\n", "\n", "One natural hypothesis is that this is because the final LayerNorm scaling has changed, which can scale up or down the final residual stream. This is slightly true, and we can see that the typical head is a bit off from the x=y line. But the average LN scaling ratio is 1.04, and this should uniformly change *all* heads by the same factor, so this can't be sufficient." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dIhckVxhHu9Q" }, "outputs": [], "source": [ "ln_scaling_no_ablation = ioi_cache[\"ln_final.hook_scale\"][\n", " t.arange(len(ioi_dataset)), ioi_dataset.word_idx[\"end\"]\n", "].squeeze()\n", "ln_scaling_ablated = ablated_cache[\"ln_final.hook_scale\"][\n", " t.arange(len(ioi_dataset)), ioi_dataset.word_idx[\"end\"]\n", "].squeeze()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "dQMNvMAOHu9Q" }, "outputs": [], "source": [ "scatter(\n", " y=ln_scaling_ablated,\n", " x=ln_scaling_no_ablation,\n", " labels={\"x\": \"No ablation\", \"y\": \"Ablation\"},\n", " title=f\"Final LN scaling factors compared (ablation vs no ablation)
Average ratio = {(ln_scaling_no_ablation / ln_scaling_ablated).mean():.4f}\",\n", " width=700,\n", " add_line=\"y=x\"\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "qv_Mgz-0Hu9Q" }, "source": [ "**Exercise to the reader:** Can you finish off this analysis? What's going on here? Why are the backup name movers changing their behaviour? Why is one negative name mover becoming significantly less important?" ] }, { "cell_type": "markdown", "metadata": { "id": "zLbO8dU5Hu9Q" }, "source": [ "## Positional vs token information being moved\n", "\n", "In section A of the appendix (titled **Disentangling token and positional signal in the output of S-Inhibition Heads**), the authors attempt to figure out whether the S-Inhibition heads are using token or positional information to supress the attention paid to `S1`. This is illustrated in my IOI diagram, by purple vs pink boxes.\n", "\n", "The way the authors find out which one is which is ingenious. They construct datasets from the original IOI dataset, with some of the signals erased or flipped. For instance, if they want to examine the effect of inverting the positional information but preserving the token information written by the S-inhibition heads, they can replace sentences like:\n", "\n", "```\n", "When Mary and John went to the store, John gave a drink to Mary\n", "```\n", "\n", "with:\n", "\n", "```\n", "When John and Mary went to the store, John gave a drink to Mary\n", "```\n", "\n", "Let's be exactly clear on why this works. The information written to the `end` token position by the S-inhibition heads will be some combination of \"don't pay attention to the duplicated token\" and \"don't pay attention to the token that's in the same position as the duplicated token\". If we run our model on the first sentence above but then patch in the second sentence, then:\n", "\n", "* The **\"don't pay attention to the duplicated token\"** signal will be unchanged (because this signal still refers to John)\n", "* The **\"don't pay attention to the token that's in the same position as the duplicated token\"** signal will flip (because this information points to the position of `Mary` in the second sentence, hence to the position of `John` in the first sentence)." ] }, { "cell_type": "markdown", "metadata": { "id": "Ta2R_x34Hu9Q" }, "source": [ "That was just one example (flipping positional information, keeping token information the same), but we can do any of six different types of flip:" ] }, { "cell_type": "markdown", "metadata": { "id": "kr3IsH2bHu9Q" }, "source": [ "| Token signal | Positional signal | Sentence | ABB -> ? 
|\n", "| ------------ | ----------------- | ----------------------------------------------------------------- | -------- |\n", "| Same | Same | `When Mary and John went to the store, John gave a drink to Mary` | ABB |\n", "| Random | Same | `When Emma and Paul went to the store, Paul gave ...` | CDD |\n", "| Inverted | Same | `When John and Mary went to the store, Mary gave ...` | BAA |\n", "| Same | Inverted | `When John and Mary went to the store, John gave ...` | BAB |\n", "| Random | Inverted | `When Paul and Emma went to the store, Emma gave ...` | DCD |\n", "| Inverted | Inverted | `When Mary and John went to the store, Mary gave ...` | ABA |" ] }, { "cell_type": "markdown", "metadata": { "id": "_E9bFj8tHu9Q" }, "source": [ "We use the `gen_flipped_prompts` method to generate each of these datasets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fs7nYDpKHu9Q" }, "outputs": [], "source": [ "datasets: list[tuple[tuple, str, IOIDataset]] = [\n", " ((0, 0), \"original\", ioi_dataset),\n", " ((1, 0), \"random token\", ioi_dataset.gen_flipped_prompts(\"ABB->CDD, BAB->DCD\")),\n", " ((2, 0), \"inverted token\", ioi_dataset.gen_flipped_prompts(\"ABB->BAA, BAB->ABA\")),\n", " ((0, 1), \"inverted position\", ioi_dataset.gen_flipped_prompts(\"ABB->BAB, BAB->ABB\")),\n", " ((1, 1), \"inverted position, random token\", ioi_dataset.gen_flipped_prompts(\"ABB->DCD, BAB->CDD\")),\n", " ((2, 1), \"inverted position, inverted token\", ioi_dataset.gen_flipped_prompts(\"ABB->ABA, BAB->BAA\")),\n", "]\n", "\n", "results = t.zeros(3, 2).to(device)\n", "\n", "s2_inhibition_heads = CIRCUIT[\"s2 inhibition\"]\n", "layers = set(layer for layer, head in s2_inhibition_heads)\n", "\n", "names_filter = lambda name: name in [utils.get_act_name(\"z\", layer) for layer in layers]\n", "\n", "\n", "def patching_hook_fn(z: Float[Tensor, \"batch seq head d_head\"], hook: HookPoint, cache: ActivationCache):\n", " heads_to_patch = [head for layer, head in s2_inhibition_heads if layer == hook.layer()]\n", " z[:, :, heads_to_patch] = cache[hook.name][:, :, heads_to_patch]\n", " return z\n", "\n", "\n", "for (row, col), desc, dataset in datasets:\n", " # Get cache of values from the modified dataset\n", " _, cache_for_patching = model.run_with_cache(dataset.toks, names_filter=names_filter, return_type=None)\n", "\n", " # Run model on IOI dataset, but patch S-inhibition heads with signals from modified dataset\n", " patched_logits = model.run_with_hooks(\n", " ioi_dataset.toks, fwd_hooks=[(names_filter, partial(patching_hook_fn, cache=cache_for_patching))]\n", " )\n", "\n", " # Get logit diff for patched results\n", " # Note, we still use IOI dataset for our \"correct answers\" reference point\n", " results[row, col] = logits_to_ave_logit_diff_2(patched_logits, ioi_dataset)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "_151QcaQHu9Q" }, "outputs": [], "source": [ "imshow(\n", " results,\n", " labels={\"x\": \"Positional signal\", \"y\": \"Token signal\"},\n", " x=[\"Original\", \"Inverted\"],\n", " y=[\"Original\", \"Random\", \"Inverted\"],\n", " title=\"Logit diff after changing all S2 inhibition heads' output signals via patching\",\n", " text_auto=\".2f\",\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "9Mzm0AajHu9Q" }, "source": [ "What are your interpretations of this plot?\n", "\n", "
\n", "Some thoughts\n", "\n", "A few sanity checks, which are what we'd expect from this plot (and hence validate our code as probably being correct):\n", "\n", "* When token and positional signals are inverted, performance is close to negative of the original performance.\n", "* Inverting the positional signal makes performance worse.\n", "* Randomizing the token signal makes performance worse.\n", "* Inverting the token signal makes performance worse. This effect is larger than randomizing (because we're pointing away from a correct answer, rather than just in a random direction).\n", "\n", "There are two main interesting findings from this plot (which we might have guessed at beforehand, but couldn't have been certain about):\n", "* The row and column differences are close to constant (column diff is approx 4.5, row diff is approx 0.75). In other words, the logit diff can be well approximated by a linear combination of the positional and token signal correlations (where correlation is 1 if the signal points towards the correct value, -1 if it points away, and 0 if it points in a random direction).\n", "* The coefficients on the positional signal correlation is **much** bigger than the coefficient on the token signal correlation. The former is about 4.5, the latter is about 1.5. This tells us that positional information is a lot more important than token information.\n", " * One possible intuition here is that name information (i.e. representing the identity of the token `\" John\"`) takes up many dimensions, so is probably harder for the model. Relative positional information on the other hand will mostly have fewer dimensions. Since the model only needs to move enough information to single out a single positional index from about 15-20, rather than single out a name in the entire set of names. This is jus a guess though, and I'd love to hear other interpretations.\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "xXwslKJxHu9Q" }, "source": [ "Let's dig a little deeper. Rather than just looking at the S-inhibition heads collectively, we can look at each of them individually." ] }, { "cell_type": "markdown", "metadata": { "id": "xBNUmzrGHu9Q" }, "source": [ "### Exercise - decompose S-Inhibition heads\n", "\n", "> ```yaml\n", "> Difficulty: 🔴🔴⚪⚪⚪\n", "> Importance: 🔵🔵🔵⚪⚪\n", ">\n", "> You should spend up to 10-15 minutes on this exercise.\n", "> This involves a lot of duplicating code from above.\n", "> ```\n", "\n", "Make the same plot as above, but after intervening on each of the S-inhibition heads individually.\n", "\n", "You can do this by creating a `results` tensor of shape `(M, 3, 2)` where `M` is the number of S-inhibition heads, and each slice contains the results of intervening on that particular head. We've given you the code to plot your results below, all you need to do is fill in `results`.\n", "\n", "(Note - we recommend computing the results as `(logit_diff - clean_logit_diff) / clean_logit_diff`, so your baseline is 0 for \"this patching has no effect\" and -1 for \"this patching completely destroys model performance\", to make the plot look clearer.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "x9RZ7ACJHu9Q" }, "outputs": [], "source": [ "results = t.zeros(len(CIRCUIT[\"s2 inhibition\"]), 3, 2).to(device)\n", "\n", "\n", "YOUR CODE HERE - fill in the `results` tensor!\n", "\n", " imshow(\n", " (results - results[0, 0, 0]) / results[0, 0, 0],\n", " labels={\"x\": \"Positional signal\", \"y\": \"Token signal\"},\n", " x=[\"Original\", \"Inverted\"],\n", " y=[\"Original\", \"Random\", \"Inverted\"],\n", " title=\"Logit diff after patching individual S2 inhibition heads (as proportion of clean logit diff)\",\n", " facet_col=0,\n", " facet_labels=[f\"{layer}.{head}\" for (layer, head) in CIRCUIT[\"s2 inhibition\"]],\n", " facet_col_spacing=0.08,\n", " width=1100,\n", " text_auto=\".2f\",\n", " )" ] }, { "cell_type": "markdown", "metadata": { "id": "I5KIFHhrHu9Q" }, "source": [ "
Solution\n", "\n", "```python\n", "results = t.zeros(len(CIRCUIT[\"s2 inhibition\"]), 3, 2).to(device)\n", "\n", "\n", "def patching_hook_fn(z: Float[Tensor, \"batch seq head d_head\"], hook: HookPoint, cache: ActivationCache, head: int):\n", " z[:, :, head] = cache[hook.name][:, :, head]\n", " return z\n", "\n", "\n", "for i, (layer, head) in enumerate(CIRCUIT[\"s2 inhibition\"]):\n", " model.reset_hooks(including_permanent=True)\n", "\n", " hook_name = utils.get_act_name(\"z\", layer)\n", "\n", " for (row, col), desc, dataset in datasets:\n", " # Get cache of values from the modified dataset\n", " _, cache_for_patching = model.run_with_cache(\n", " dataset.toks, names_filter=lambda name: name == hook_name, return_type=None\n", " )\n", "\n", " # Run model on IOI dataset, but patch S-inhibition heads with signals from modified dataset\n", " patched_logits = model.run_with_hooks(\n", " ioi_dataset.toks,\n", " fwd_hooks=[(hook_name, partial(patching_hook_fn, cache=cache_for_patching, head=head))],\n", " )\n", "\n", " # Get logit diff for patched results\n", " # Note, we still use IOI dataset for our \"correct answers\" reference point\n", " results[i, row, col] = logits_to_ave_logit_diff_2(patched_logits, ioi_dataset)\n", "\n", "imshow(\n", " (results - results[0, 0, 0]) / results[0, 0, 0],\n", " labels={\"x\": \"Positional signal\", \"y\": \"Token signal\"},\n", " x=[\"Original\", \"Inverted\"],\n", " y=[\"Original\", \"Random\", \"Inverted\"],\n", " title=\"Logit diff after patching individual S2 inhibition heads (as proportion of clean logit diff)\",\n", " facet_col=0,\n", " facet_labels=[f\"{layer}.{head}\" for (layer, head) in CIRCUIT[\"s2 inhibition\"]],\n", " facet_col_spacing=0.08,\n", " width=1100,\n", " text_auto=\".2f\",\n", ")\n", "```\n", "
" ] }, { "cell_type": "markdown", "metadata": { "id": "NP3jUxGaHu9Q" }, "source": [ "Noteworthy features of this plot:\n", "\n", "* Every head cares more about positional signal than token signal\n", "* Head `8.6` (the biggest S-inhibition head) cares MUCH more about positional signal, in fact it doesn't care at all about token signal\n", " * Suggests maybe `8.6` was the first head which learned to do this task, and subsequent heads basically just helped out by providing the remainder (which was token signal). Shows the heads specialise.\n", "* The only heads that kinda care about token signal are `7.9` and `8.10` (but they still both care about positional signal almost twice as much)\n", "* The approximation of logit diff as a sum of positional and token signal correlations still seems to hold for each head individually, although the coefficients for each head are different." ] }, { "cell_type": "markdown", "metadata": { "id": "3ibJXw29Hu9Q" }, "source": [ "## Further Reading\n", "\n", "Here is a collection of links for further reading, which haven't already been mentioned:\n", "\n", "* [Some Lessons Learned from Studying Indirect Object Identification in GPT-2 small](https://www.alignmentforum.org/posts/3ecs6duLmTfyra3Gp/some-lessons-learned-from-studying-indirect-object)\n", " * A blog post by the authors of this paper, which goes into more detail about the experiments and results.\n", "* [Causal Scrubbing: a method for rigorously testing interpretability hypotheses [Redwood Research]](https://www.alignmentforum.org/posts/JvZhhzycHu2Yd57RN/causal-scrubbing-a-method-for-rigorously-testing)\n", " * Introduces the idea of causal scubbing, a proposed systematic method for evaluting the quality of mechanistic interpretations" ] }, { "cell_type": "markdown", "metadata": { "id": "aIhe1q9BHu9Q" }, "source": [ "## Suggested topics for further exploration\n", "\n", "Here are some future directions, some suggested by Neel, others by the authors of the paper. Many of these might make good capstone projects!\n", "\n", "* 3 letter acronyms (or more!)\n", "* Converting names to emails.\n", " * An extension task is e.g. constructing an email from a snippet like the following: Name: Neel Nanda; Email: last name dot first name k @ gmail\n", "* Grammatical rules\n", " * Learning that words after full stops are capital letters\n", " * Verb conjugation\n", " * Choosing the right pronouns (e.g. he vs she vs it vs they)\n", " * Whether something is a proper noun or not\n", "* Detecting sentiment (e.g. predicting whether something will be described as good vs bad)\n", "* Interpreting memorisation. E.g., there are times when GPT-2 knows surprising facts like people’s contact information. How does that happen?\n", "* Counting objects described in text. E.g.: I picked up an apple, a pear, and an orange. I was holding three fruits.\n", "* Extensions from Alex Variengien\n", " * Understanding what's happening in the adversarial examples: most notably S-Inhibition Head attention pattern (hard). (S-Inhibition heads are mentioned in the IOI paper)\n", " * Understanding how are positional signal encoded (relative distance, something else?) bonus point if we have a story that include the positional embeddings and that explain how the difference between position is computed (if relative is the right framework) by Duplicate Token Heads / Induction Heads. (hard, but less context dependant)\n", " * What are the role of MLPs in IOI (quite broad and hard)\n", " * What is the role of Duplicate Token Heads outside IOI? 
Are they used in other Q-compositions with S-Inhibition Heads? Can we describe how their QK circuit implement \"collision detection\" at a parameter level? (Last question is low context dependant and quite tractable)\n", " * What is the role of Negative/ Backup/ regular Name Movers Heads outside IOI? Can we find examples on which Negative Name Movers contribute positively to the next-token prediction?\n", " * What are the differences between the 5 inductions heads present in GPT2-small? What are the heads they rely on / what are the later heads they compose with (low context dependence form IOI)\n", " * Understanding 4.11, (a really sharp previous token heads) at the parameter level. I think this can be quite tractable given that its attention pattern is almost perfectly off-diagonal\n", "* What are the conditions for compensation mechanisms to occur? Is it due to drop-out? (Arthur Conmy is working on this - feel free to reach out to arthur@rdwrs.com )\n", "* Extensions from Arthur Conmy\n", " * Understand IOI in GPT-Neo: it's a same size model but does IOI via composition of MLPs\n", " * Understand IOI in the Stanford mistral models - they all seem to do IOI in the same way, so maybe look at the development of the circuit through training?\n", "* [Help out Redwood Research’s interpretability team by finding heuristics implemented by GPT-2 small](https://www.lesswrong.com/posts/LkBmAGJgZX2tbwGKg/help-out-redwood-research-s-interpretability-team-by-finding)\n", " * This LessWrong post from 6 months ago outlines some features of the IOI task that made it a good choice to study, and suggests other tasks that might meet these criteria / how you might go about finding such tasks." ] }, { "cell_type": "markdown", "metadata": { "id": "F5btVWTmHu9R" }, "source": [ "## Suggested paper replications\n", "\n", "
\n", "\n", "### [A circuit for Python docstrings in a 4-layer attention-only transformer](https://www.lesswrong.com/posts/u6KXXmKFbXfWzoAXn/a-circuit-for-python-docstrings-in-a-4-layer-attention-only)\n", "\n", "This work was produced as part of the SERI ML Alignment Theory Scholars Program (Winter 2022) under the supervision of Neel Nanda. Similar to how the IOI paper searched for in some sense the simplest kind of circuit which required 3 layers, this work was looking for the simplest kind of circuit which required 4 layers. The task they investigated was the **docstring task** - can you predict parameters in the right order, in situations like this:\n", "\n", "```python\n", "def port(self, load, size, files, last):\n", " '''oil column piece\n", "\n", " :param load: crime population\n", " :param size: unit dark\n", " :param\n", "```\n", "\n", "The token that follows should be ` files`, and just like in the case of IOI we can deeply analyze how the transformer solves this task. Unlike IOI, we're looking at a 4-layer transformer which was trained on code (not GPT2-Small), which makes a lot of the analysis cleaner (even though the circuit has more levels of composition than IOI does).\n", "\n", "For an extra challenge, rather than replicating the authors' results, you can try and perform this investigation yourself, without seeing what tools the authors of the paper used! Most will be similar to the ones you've used in the exercises so far.\n", "\n", "This might be a good replication for you if:\n", "\n", "* You enjoyed most/all sections of these exercises, and want to practice using the tools you learned in a different context\n", "* You don't want to try anything *too* far left field from the content of these exercises (this post also comes with a Colab notebook which can be referred to if you're stuck)\n", "\n", "
\n", "\n", "### [Mechanistically interpreting time in GPT-2 small](https://www.lesswrong.com/posts/6tHNM2s6SWzFHv3Wo/mechanistically-interpreting-time-in-gpt-2-small)\n", "\n", "This work was done by a group of independent researchers, supervised by Arthur Conmy. The task was to interpret how GPT2-Small can solve the task of next day prediction, i.e. predicting the next token in sentences like *If today is Monday, tomorrow is*. This replication is easier than the previous one, since the core circuit only contains one attention head rather than a composition of several. For this reason, we'd more strongly encourage trying to do this replication without guidance, i.e. have a crack at it before reading the full post.\n", "\n", "
\n", "\n", "### [How does GPT-2 compute greater-than? Interpreting mathematical abilities in a pre-trained language model](https://openreview.net/pdf?id=p4PckNQR8k)\n", "\n", "This paper came out of the REMIX program run by Redwood Research. It analyses a circuit in GPT2-Small, much like this one. Here, the circuit is for computing greater-than; in other words detecting that sentences like *The war lasted from the year 1732 to the year 17...* will be completed by valid two-digit end years (years > 32). The paper identifies a circuit, explains the role of each circuit component, and also finds related tasks that activate the circuit.\n", "\n", "For an extra challenge, rather than replicating this paper, you can try and perform this investigation yourself, without seeing what tools the authors of the paper used! Many will be similar to the ones you've used in the exercises so far; some will be different. In particular, there will be some analysis of individual neurons in this replication, unlike in this IOI notebook.\n", "\n", "This might be a good replication for you if:\n", "\n", "* You enjoyed most/all sections of these exercises, and want to practice using the tools you learned in a different context\n", "* You don't want to try anything *too* far left field from the content of these exercises (although this is probably more challenging than the docstring circuit, mainly due to the presence of MLPs)\n", "\n", "
\n", "\n", "### [Towards Automated Circuit Discovery for Mechanistic Interpretability](https://arxiv.org/abs/2304.14997) / [Attribution Patching Outperforms Automated Circuit Discovery](https://arxiv.org/abs/2310.10348)\n", "\n", "These two are grouped together, because they both study the possibility of scaling circuit discovery tools using automation. **Automated Circuit Discovery** (ACDC) was a tool discovered by Arthur Conmy & collaborators partially done at Redwood Research. They automate several of the techniques we've used in these exercises (in particular activation & patch patching), iteratively severing connections between components until we're left with a minimal circuit which can still effectively solve the task. **Attribution patching** is slightly different; it approximates the effect of each component on the model's output (and by extension its importance in a particular task) by a simple and computationally cheap first-order approximation.\n", "\n", "Either of these techniques would serve as suitable replication challenges over the course of a week. There's also a lot of work that can still be done in improving or refining these techniques.\n", "\n", "This might be a good replication for you if:\n", "\n", "* You enjoyed the exercises in this section, in particular those on activation and path patching\n", "* You're looking for a project which is more focused on coding and implementation than on theory\n", "* You're excited about exploring ways to scale / automate circuit discovery efficiently" ] } ], "metadata": { "language_info": { "name": "python" }, "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" } }, "nbformat": 4, "nbformat_minor": 0 }