{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "-c8CAtScHu88"
},
"source": [
"# [1.4.1] Indirect Object Identification (exercises)\n",
"\n",
"> **ARENA [Streamlit Page](https://arena-chapter1-transformer-interp.streamlit.app/21_📚_[1.4.1]_Indirect_Object_Identification)**\n",
">\n",
"> **Colab: [exercises](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter1_transformer_interp/exercises/part41_indirect_object_identification/1.4.1_Indirect_Object_Identification_exercises.ipynb?t=20250305) | [solutions](https://colab.research.google.com/github/callummcdougall/ARENA_3.0/blob/main/chapter1_transformer_interp/exercises/part41_indirect_object_identification/1.4.1_Indirect_Object_Identification_solutions.ipynb?t=20250305)**\n",
"\n",
"Please send any problems / bugs on the `#errata` channel in the [Slack group](https://join.slack.com/t/arena-uk/shared_invite/zt-2zick19fl-6GY1yoGaoUozyM3wObwmnQ), and ask any questions on the dedicated channels for this chapter of material.\n",
"\n",
"You can collapse each section so only the headers are visible, by clicking the arrow symbol on the left hand side of the markdown header cells.\n",
"\n",
"Links to all other chapters: [(0) Fundamentals](https://arena-chapter0-fundamentals.streamlit.app/), [(1) Transformer Interpretability](https://arena-chapter1-transformer-interp.streamlit.app/), [(2) RL](https://arena-chapter2-rl.streamlit.app/)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wZEmBjgXHu89"
},
"source": [
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "vYkhTw4LHu8-"
},
"source": [
"# Introduction"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VXR4dZ94Hu8-"
},
"source": [
"This notebook / document is built around the [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper, in which the authors aim to understand the **indirect object identification circuit** in GPT-2 small. This circuit is resposible for the model's ability to complete sentences like `\"John and Mary went to the shops, John gave a bag to\"` with the correct token \"`\" Mary\"`.\n",
"\n",
"It is loosely divided into different sections, each one with their own flavour. Sections 1, 2 & 3 are derived from Neel Nanda's notebook [Exploratory_Analysis_Demo](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=WXktSe0CvBdh). The flavour of these exercises is experimental and loose, with a focus on demonstrating what exploratory analysis looks like in practice with the transformerlens library. They skimp on rigour, and instead try to speedrun the process of finding suggestive evidence for this circuit. The code and exercises are simple and generic, but accompanied with a lot of detail about what each stage is doing, and why (plus several optional details and tangents). Section 4 introduces you to the idea of **path patching**, which is a more rigorous and structured way of analysing the model's behaviour. Here, you'll be replicating some of the results of the paper, which will serve to rigorously validate the insights gained from earlier sections. It's the most technically dense of all five sections. Lastly, sections 5 & 6 are much less structured, and have a stronger focus on open-ended exercises & letting you go off and explore for yourself.\n",
"\n",
"Which exercises you want to do will depend on what you're hoping to get out of these exercises. For example:\n",
"\n",
"* You want to understand activation patching - **1, 2, 3**\n",
"* You want to get a sense of how to do exploratory analysis on a model - **1, 2, 3**\n",
"* You want to understand activation and path patching - **1, 2, 3, 4**\n",
"* You want to understand the IOI circuit fully, and replicate the paper's key results - **1, 2, 3, 4, 5**\n",
"* You want to understand the IOI circuit fully, and replicate the paper's key results (but you already understand activation patching) - **1, 2, 4, 5**\n",
"* You want to understand IOI, and then dive deeper e.g. by looking for more circuits in models or investigating anomalies - **1, 2, 3, 4, 5, 6**\n",
"\n",
"*Note - if you find yourself getting frequent CUDA memory errors, you can periodically call `torch.cuda.empty_cache()` to [free up some memory](https://stackoverflow.com/questions/57858433/how-to-clear-gpu-memory-after-pytorch-model-training-without-restarting-kernel).*\n",
"\n",
"Each exercise will have a difficulty and importance rating out of 5, as well as an estimated maximum time you should spend on these exercises and sometimes a short annotation. You should interpret the ratings & time estimates relatively (e.g. if you find yourself spending about 50% longer on the exercises than the time estimates, adjust accordingly). Please do skip exercises / look at solutions if you don't feel like they're important enough to be worth doing, and you'd rather get to the good stuff!"
]
},
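{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a minimal sketch of the memory tip above (assuming you're running on a CUDA GPU, and that `torch` has been imported as `t` as in the rest of this notebook), you can periodically run something like:\n",
"\n",
"```python\n",
"import gc\n",
"\n",
"gc.collect()          # drop Python references to dead objects first\n",
"t.cuda.empty_cache()  # then release cached GPU memory back to the driver\n",
"```\n",
"\n",
"Note this only frees memory that's no longer referenced - if you're still holding onto large tensors (e.g. cached activations), you'll need to `del` them first."
]
},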
{
"cell_type": "markdown",
"metadata": {
"id": "yaOlw6n6Hu8-"
},
"source": [
"## The purpose / structure of these exercises\n",
"\n",
"At a surface level, these exercises are designed to take you through the indirect object identification circuit. But it's also designed to make you a better interpretability researcher! As a result, most exercises will be doing a combination of:\n",
"\n",
"1. Showing you some new feature/component of the circuit, and\n",
"2. Teaching you how to use tools and interpret results in a broader mech interp context.\n",
"\n",
"A key idea to have in mind during these exercises is the **spectrum from simpler, more exploratory tools to more rigoruous, complex tools**. On the simpler side, you have something like inspecting attention patterns, which can give a decent (but sometimes misleading) picture of what an attention head is doing. These should be some of the first tools you reach for, and you should be using them a lot even before you have concrete hypotheses about a circuit. On the more rigorous side, you have something like path patching, which is a pretty rigorous and effortful tool that is best used when you already have reasonably concrete hypotheses about a circuit. As we go through the exercises, we'll transition from left to right along this spectrum."
]
},
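{
"cell_type": "markdown",
"metadata": {},
"source": [
"To give a feel for the exploratory end of this spectrum, here's a rough sketch of how little code it takes to eyeball attention patterns with the `circuitsvis` library (this assumes GPT-2 small has already been loaded as `model`, as we do later in this notebook; the layer choice is arbitrary):\n",
"\n",
"```python\n",
"import circuitsvis as cv\n",
"\n",
"text = \"When Mary and John went to the store, John gave a drink to\"\n",
"tokens = model.to_tokens(text)\n",
"_, cache = model.run_with_cache(tokens)\n",
"\n",
"# Visualise the attention patterns of every head in one layer (layer 9, chosen arbitrarily)\n",
"cv.attention.attention_patterns(\n",
"    tokens=model.to_str_tokens(text),\n",
"    attention=cache[\"pattern\", 9][0],  # shape [n_heads, seq_dest, seq_src]\n",
")\n",
"```\n",
"\n",
"Path patching, by contrast, needs careful bookkeeping of which activations flow along which paths - which is why we save it for when we already have concrete hypotheses."
]
},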
{
"cell_type": "markdown",
"metadata": {
"id": "sK1K_aVVHu8-"
},
"source": [
"## The IOI task\n",
"\n",
"The first step when trying to reverse engineer a circuit in a model is to identify *what* capability we want to reverse engineer. Indirect Object Identification is a task studied in Redwood Research's excellent [Interpretability in the Wild](https://arxiv.org/abs/2211.00593) paper (see [Neel Nanda's interview with the authors](https://www.youtube.com/watch?v=gzwj0jWbvbo) or [Kevin Wang's Twitter thread](https://threadreaderapp.com/thread/1587601532639494146.html) for an overview). The task is to complete sentences like \"When Mary and John went to the store, John gave a drink to\" with \" Mary\" rather than \" John\".\n",
"\n",
"In the paper they rigorously reverse engineer a 26 head circuit, with 7 separate categories of heads used to perform this capability. The circuit they found roughly breaks down into three parts:\n",
"\n",
"1. Identify what names are in the sentence\n",
"2. Identify which names are duplicated\n",
"3. Predict the name that is *not* duplicated\n",
"\n",
"Why was this task chosen? The authors give a very good explanation for their choice in their [video walkthrough of their paper](https://www.youtube.com/watch?v=gzwj0jWbvbo), which you are encouraged to watch. To be brief, some of the reasons were:\n",
"\n",
"* This is a fairly common grammatical structure, so we should expect the model to build some circuitry for solving it quite early on (after it's finished with all the more basic stuff, like n-grams, punctuation, induction, and simpler grammatical structures than this one).\n",
"* It's easy to measure: the model always puts a much higher probability on the IO and S tokens (i.e. `\" Mary\"` and `\" John\"`) than any others, and this is especially true once the model starts being stripped down to the core part of the circuit we're studying. So we can just take the logit difference between these two tokens, and use this as a metric for how well the model can solve the task.\n",
"* It is a crisp and well-defined task, so less likely to be solved in terms of memorisation of a large bag of heuristics (unlike e.g. tasks like \"predict that the number `n+1` will follow `n`, which as Neel mentions in the video walkthrough is actually much more annoying and subtle than it first seems!).\n",
"\n",
"A terminology note: `IO` will refer to the indirect object (in the example, `\" Mary\"`), `S1` and `S2` will refer to the two instances of the subject token (i.e. `\" John\"`), and `end` will refer to the end token `\" to\"` (because this is the position we take our prediction from, and we don't care about any tokens after this point). We will also sometimes use `S` to refer to the identity of the subject token (rather than referring to the first or second instance in particular)."
]
},
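{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the \"easy to measure\" point concrete, here is a minimal sketch of the logit difference metric for the example sentence above (assuming GPT-2 small has been loaded as `model`, a `HookedTransformer`; the exercises below build a more careful, batched version of this):\n",
"\n",
"```python\n",
"prompt = \"When Mary and John went to the store, John gave a drink to\"\n",
"io_token = model.to_single_token(\" Mary\")   # indirect object (correct answer)\n",
"s_token = model.to_single_token(\" John\")    # subject (incorrect answer)\n",
"\n",
"logits = model(prompt, return_type=\"logits\")  # shape [batch=1, seq, d_vocab]\n",
"final_logits = logits[0, -1]                  # prediction is read off at the final position (the \" to\" token)\n",
"\n",
"logit_diff = final_logits[io_token] - final_logits[s_token]\n",
"print(f\"Logit difference (IO - S): {logit_diff.item():.3f}\")\n",
"```\n",
"\n",
"A large positive logit difference means the model strongly prefers the IO token over the S token at the `end` position."
]
},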
{
"cell_type": "markdown",
"metadata": {
"id": "kpk_nMt0Hu8-"
},
"source": [
"## Keeping track of your guesses & predictions\n",
"\n",
"There's a lot to keep track of in these exercises as we work through them. You'll be exposed to new functions and modules from transformerlens, new ways to causally intervene in models, all the while building up your understanding of how the IOI task is performed. The notebook starts off exploratory in nature (lots of plotting and investigation), and gradually moves into more technical details, refined analysis, and replication of the paper's results, as we improve our understanding of the IOI circuit. You are recommended to keep a document or page of notes nearby as you go through these exercises, so you can keep track of the main takeaways from each section, as well as your hypotheses for how the model performs the task, and your ideas for how you might go off and test these hypotheses on your own if the notebook were to suddenly end.\n",
"\n",
"If you are feeling extremely confused at any point, you can come back to the dropdown below, which contains diagrams explaining how the circuit works. There is also an accompanying intuitive explanation which you might find more helpful. However, I'd recommend you try and go through the notebook unassisted before looking at these.\n",
"\n",
"
Intuitive explanation of IOI circuit
\n",
"\n",
"First, let's start with an analogy for how transformers work (you can skip this if you've already read [my post](https://www.lesswrong.com/posts/euam65XjigaCJQkcN/an-analogy-for-understanding-transformers)). Imagine a line of people, who can only look forward. Each person has a token written on their chest, and their goal is to figure out what token the person in front of them is holding. Each person is allowed to pass a question backwards along the line (not forwards), and anyone can choose to reply to that question by passing information forwards to the person who asked. In this case, the sentence is `\"When Mary and John went to the store, John gave a drink to Mary\"`. You are the person holding the `\" to\"` token, and your goal is to figure out that the person in front of him has the `\" Mary\"` token.\n",
"\n",
"To be clear about how this analogy relates to transformers:\n",
"* Each person in the line represents a vector in the residual stream. Initially they just store their own token, but they accrue more information as they ask questions and receive answers (i.e. as components write to the residual stream)\n",
"* The operation of an attention head is represented by a question & answer:\n",
" * The person who asks is the destination token, the people who answer are the source tokens\n",
" * The question is the query vector\n",
" * The information *which determines who answers the question* is the key vector\n",
" * The information *which gets passed back to the original asker* is the value vector\n",
"\n",
"Now, here is how the IOI circuit works in this analogy. Each bullet point represents a class of attention heads.\n",
"\n",
"* The person with the second `\" John\"` token asks the question \"does anyone else hold the name `\" John\"`?\". They get a reply from the first `\" John\"` token, who also gives him their location. So he now knows that `\" John\"` is repeated, and he knows that the first `\" John\"` token is 4th in the sequence.\n",
" * These are *Duplicate Token Heads*\n",
"* You ask the question \"which names are repeated?\", and you get an answer from the person holding the second `\" John\"` token. You now also know that `\" John\"` is repeated, and where the first `\" John\"` token is.\n",
" * These are *S-Inhibition Heads*\n",
"* You ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\". You get a reply from the person holding the `\" Mary\"` token, who tells you that they have name `\" Mary\"`. You use this as your prediction.\n",
" * These are *Name Mover Heads*\n",
"\n",
"This is a fine first-pass understanding of how the circuit works. A few other features:\n",
"\n",
"* The person after the first `\" John\"` (holding `\" went\"`) had previously asked about the identity of the person behind him. So he knows that the 4th person in the sequence holds the `\" John\"` token, meaning he can also reply to the question of the person holding the second `\" John\"` token. *(previous token heads / induction heads)*\n",
" * This might not seem necessary, but since previous token heads / induction heads are just a pretty useful thing to have in general, it makes sense that you'd want to make use of this information!\n",
"* If for some reason you forget to ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\", then you'll have another chance to do this.\n",
" * These are *(Backup Name Mover Heads)*\n",
" * Their existance might be partly because transformers are trained with **dropout**. This can make them \"forget\" things, so it's important to have a backup method for recovering that information!\n",
"* You want to avoid overconfidence, so you also ask the question \"does anyone have a name that isn't `\" John\"`, and isn't at the 4th position in the sequence?\" another time, in order to ***anti-***predict the response that you get from this question. *(negative name mover heads)*\n",
" * Yes, this is as weird as it sounds! The authors speculate that these heads \"hedge\" the predictions, avoiding high cross-entropy loss when making mistakes.\n",
"\n",
"Diagram 1 (simple)
\n",
"\n",
"\n",
"\n",
"
Diagram 2 (complex)
\n",
"\n",
"\n",
"\n",
"
refactor_factored_attn_matrices
(optional)fold_ln
, center_unembed
and center_writing_weights
(optional)Performance on answer token:\n", "Rank: 0 Logit: 18.09 Prob: 70.07% Token: | Mary|\n", "\n" ] }, "metadata": {} }, { "output_type": "stream", "name": "stdout", "text": [ "Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|\n", "Top 1th token. Logit: 15.38 Prob: 4.67% Token: | the|\n", "Top 2th token. Logit: 15.35 Prob: 4.54% Token: | John|\n", "Top 3th token. Logit: 15.25 Prob: 4.11% Token: | them|\n", "Top 4th token. Logit: 14.84 Prob: 2.73% Token: | his|\n", "Top 5th token. Logit: 14.06 Prob: 1.24% Token: | her|\n", "Top 6th token. Logit: 13.54 Prob: 0.74% Token: | a|\n", "Top 7th token. Logit: 13.52 Prob: 0.73% Token: | their|\n", "Top 8th token. Logit: 13.13 Prob: 0.49% Token: | Jesus|\n", "Top 9th token. Logit: 12.97 Prob: 0.42% Token: | him|\n" ] }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1mRanks of the answer tokens:\u001b[0m \u001b[1m[\u001b[0m\u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[1;36m0\u001b[0m\u001b[1m)\u001b[0m\u001b[1m]\u001b[0m\n" ], "text/html": [ "
Ranks of the answer tokens: [(' Mary', 0)]\n", "\n" ] }, "metadata": {} } ], "source": [ "# Here is where we test on a single prompt\n", "# Result: 70% probability on Mary, as we expect\n", "\n", "example_prompt = \"After John and Mary went to the store, John gave a bottle of milk to\"\n", "example_answer = \" Mary\"\n", "utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "_jNBtLU-Hu9A" }, "source": [ "We now want to find a reference prompt to run the model on. Even though our ultimate goal is to reverse engineer how this behaviour is done in general, often the best way to start out in mechanistic interpretability is by zooming in on a concrete example and understanding it in detail, and only *then* zooming out and verifying that our analysis generalises. In section 3, we'll work with a dataset similar to the one used by the paper authors, but this probably wouldn't be the first thing we reached for if we were just doing initial investigations.\n", "\n", "We'll run the model on 4 instances of this task, each prompt given twice - one with the first name as the indirect object, one with the second name. To make our lives easier, we'll carefully choose prompts with single token names and the corresponding names in the same token positions.\n", "\n", "
[\n", " 'When John and Mary went to the shops, John gave the bag to',\n", " 'When John and Mary went to the shops, Mary gave the bag to',\n", " 'When Tom and James went to the park, James gave the ball to',\n", " 'When Tom and James went to the park, Tom gave the ball to',\n", " 'When Dan and Sid went to the shops, Sid gave an apple to',\n", " 'When Dan and Sid went to the shops, Dan gave an apple to',\n", " 'After Martin and Amy went to the park, Amy gave a drink to',\n", " 'After Martin and Amy went to the park, Martin gave a drink to'\n", "]\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1m[\u001b[0m\n", " \u001b[1m(\u001b[0m\u001b[32m' Mary'\u001b[0m, \u001b[32m' John'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' John'\u001b[0m, \u001b[32m' Mary'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Tom'\u001b[0m, \u001b[32m' James'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' James'\u001b[0m, \u001b[32m' Tom'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Dan'\u001b[0m, \u001b[32m' Sid'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Sid'\u001b[0m, \u001b[32m' Dan'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Martin'\u001b[0m, \u001b[32m' Amy'\u001b[0m\u001b[1m)\u001b[0m,\n", " \u001b[1m(\u001b[0m\u001b[32m' Amy'\u001b[0m, \u001b[32m' Martin'\u001b[0m\u001b[1m)\u001b[0m\n", "\u001b[1m]\u001b[0m\n" ], "text/html": [ "
[\n", " (' Mary', ' John'),\n", " (' John', ' Mary'),\n", " (' Tom', ' James'),\n", " (' James', ' Tom'),\n", " (' Dan', ' Sid'),\n", " (' Sid', ' Dan'),\n", " (' Martin', ' Amy'),\n", " (' Amy', ' Martin')\n", "]\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[1;35mtensor\u001b[0m\u001b[1m(\u001b[0m\u001b[1m[\u001b[0m\u001b[1m[\u001b[0m \u001b[1;36m5335\u001b[0m, \u001b[1;36m1757\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m1757\u001b[0m, \u001b[1;36m5335\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m4186\u001b[0m, \u001b[1;36m3700\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m3700\u001b[0m, \u001b[1;36m4186\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m6035\u001b[0m, \u001b[1;36m15686\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m15686\u001b[0m, \u001b[1;36m6035\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m \u001b[1;36m5780\u001b[0m, \u001b[1;36m14235\u001b[0m\u001b[1m]\u001b[0m,\n", " \u001b[1m[\u001b[0m\u001b[1;36m14235\u001b[0m, \u001b[1;36m5780\u001b[0m\u001b[1m]\u001b[0m\u001b[1m]\u001b[0m\u001b[1m)\u001b[0m\n" ], "text/html": [ "
tensor([[ 5335, 1757],\n", " [ 1757, 5335],\n", " [ 4186, 3700],\n", " [ 3700, 4186],\n", " [ 6035, 15686],\n", " [15686, 6035],\n", " [ 5780, 14235],\n", " [14235, 5780]])\n", "\n" ] }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": [ "\u001b[3m Prompts & Answers: \u001b[0m\n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1mPrompt \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mCorrect \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mIncorrect\u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩\n", "│ When John and Mary went to the shops, John gave the bag to │ ' Mary' │ ' John' │\n", "│ When John and Mary went to the shops, Mary gave the bag to │ ' John' │ ' Mary' │\n", "│ When Tom and James went to the park, James gave the ball to │ ' Tom' │ ' James' │\n", "│ When Tom and James went to the park, Tom gave the ball to │ ' James' │ ' Tom' │\n", "│ When Dan and Sid went to the shops, Sid gave an apple to │ ' Dan' │ ' Sid' │\n", "│ When Dan and Sid went to the shops, Dan gave an apple to │ ' Sid' │ ' Dan' │\n", "│ After Martin and Amy went to the park, Amy gave a drink to │ ' Martin' │ ' Amy' │\n", "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy' │ ' Martin' │\n", "└───────────────────────────────────────────────────────────────┴───────────┴───────────┘\n" ], "text/html": [ "
Prompts & Answers: \n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┓\n", "┃ Prompt ┃ Correct ┃ Incorrect ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━┩\n", "│ When John and Mary went to the shops, John gave the bag to │ ' Mary' │ ' John' │\n", "│ When John and Mary went to the shops, Mary gave the bag to │ ' John' │ ' Mary' │\n", "│ When Tom and James went to the park, James gave the ball to │ ' Tom' │ ' James' │\n", "│ When Tom and James went to the park, Tom gave the ball to │ ' James' │ ' Tom' │\n", "│ When Dan and Sid went to the shops, Sid gave an apple to │ ' Dan' │ ' Sid' │\n", "│ When Dan and Sid went to the shops, Dan gave an apple to │ ' Sid' │ ' Dan' │\n", "│ After Martin and Amy went to the park, Amy gave a drink to │ ' Martin' │ ' Amy' │\n", "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy' │ ' Martin' │\n", "└───────────────────────────────────────────────────────────────┴───────────┴───────────┘\n", "\n" ] }, "metadata": {} } ], "source": [ "prompt_format = [\n", " \"When John and Mary went to the shops,{} gave the bag to\",\n", " \"When Tom and James went to the park,{} gave the ball to\",\n", " \"When Dan and Sid went to the shops,{} gave an apple to\",\n", " \"After Martin and Amy went to the park,{} gave a drink to\",\n", "]\n", "name_pairs = [\n", " (\" Mary\", \" John\"),\n", " (\" Tom\", \" James\"),\n", " (\" Dan\", \" Sid\"),\n", " (\" Martin\", \" Amy\"),\n", "]\n", "\n", "# Define 8 prompts, in 4 groups of 2 (with adjacent prompts having answers swapped)\n", "prompts = [prompt.format(name) for (prompt, names) in zip(prompt_format, name_pairs) for name in names[::-1]]\n", "# Define the answers for each prompt, in the form (correct, incorrect)\n", "answers = [names[::i] for names in name_pairs for i in (1, -1)]\n", "# Define the answer tokens (same shape as the answers)\n", "answer_tokens = t.concat([model.to_tokens(names, prepend_bos=False).T for names in answers])\n", "\n", "rprint(prompts)\n", "rprint(answers)\n", "rprint(answer_tokens)\n", "\n", "table = Table(\"Prompt\", \"Correct\", \"Incorrect\", title=\"Prompts & Answers:\")\n", "\n", "for prompt, answer in zip(prompts, answers):\n", " table.add_row(prompt, repr(answer[0]), repr(answer[1]))\n", "\n", "rprint(table)" ] }, { "cell_type": "markdown", "metadata": { "id": "cqw_akj8Hu9A" }, "source": [ "
rich
libraryLogit differences \n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓\n", "┃ Prompt ┃ Correct ┃ Incorrect ┃ Logit Difference ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩\n", "│ When John and Mary went to the shops, John gave the bag to │ ' Mary' │ ' John' │ 3.337 │\n", "│ When John and Mary went to the shops, Mary gave the bag to │ ' John' │ ' Mary' │ 3.202 │\n", "│ When Tom and James went to the park, James gave the ball to │ ' Tom' │ ' James' │ 2.709 │\n", "│ When Tom and James went to the park, Tom gave the ball to │ ' James' │ ' Tom' │ 3.797 │\n", "│ When Dan and Sid went to the shops, Sid gave an apple to │ ' Dan' │ ' Sid' │ 1.720 │\n", "│ When Dan and Sid went to the shops, Dan gave an apple to │ ' Sid' │ ' Dan' │ 5.281 │\n", "│ After Martin and Amy went to the park, Amy gave a drink to │ ' Martin' │ ' Amy' │ 2.601 │\n", "│ After Martin and Amy went to the park, Martin gave a drink to │ ' Amy' │ ' Martin' │ 5.767 │\n", "└───────────────────────────────────────────────────────────────┴───────────┴───────────┴──────────────────┘\n", "\n" ] }, "metadata": {} } ], "source": [ "def logits_to_ave_logit_diff(\n", " logits: Float[Tensor, \"batch seq d_vocab\"],\n", " answer_tokens: Float[Tensor, \"batch 2\"] = answer_tokens,\n", " per_prompt: bool = False,\n", ") -> Float[Tensor, \"*batch\"]:\n", " \"\"\"\n", " Returns logit difference between the correct and incorrect answer.\n", "\n", " If per_prompt=True, return the array of differences rather than the average.\n", " \"\"\"\n", " batch_idx = t.arange(logits.size(0))\n", "\n", " correct = logits[:, -1, :][batch_idx, answer_tokens[:, 0]]\n", " incorrect = logits[:, -1, :][batch_idx, answer_tokens[:, 1]]\n", " if per_prompt: return correct - incorrect\n", " return t.mean(correct-incorrect)\n", "\n", "\n", "tests.test_logits_to_ave_logit_diff(logits_to_ave_logit_diff)\n", "\n", "original_per_prompt_diff = logits_to_ave_logit_diff(original_logits, answer_tokens, per_prompt=True)\n", "print(\"Per prompt logit difference:\", original_per_prompt_diff)\n", "original_average_logit_diff = logits_to_ave_logit_diff(original_logits, answer_tokens)\n", "print(\"Average logit difference:\", original_average_logit_diff)\n", "\n", "cols = [\n", " \"Prompt\",\n", " Column(\"Correct\", style=\"rgb(0,200,0) bold\"),\n", " Column(\"Incorrect\", style=\"rgb(255,0,0) bold\"),\n", " Column(\"Logit Difference\", style=\"bold\"),\n", "]\n", "table = Table(*cols, title=\"Logit differences\")\n", "\n", "for prompt, answer, logit_diff in zip(prompts, answers, original_per_prompt_diff):\n", " table.add_row(prompt, repr(answer[0]), repr(answer[1]), f\"{logit_diff.item():.3f}\")\n", "\n", "rprint(table)" ] }, { "cell_type": "markdown", "metadata": { "id": "cSu6RMvSHu9A" }, "source": [ "
accumulated_resid
attention_heads
plots are behaving weirdly.\"When Mary and John went to the store, John gave a drink to\"
) and our corrupted input will have the subject token flipped (e.g. `\"When Mary and John went to the store, Mary gave a drink to\"`). Patching by replacing corrupted residual stream values with clean values is a causal intervention which will allow us to understand precisely which parts of the network are identifying the indirect object. If a component is important, then patching in its clean output (i.e. replacing that component's corrupted output with its clean output) will reverse the signal that this component produces, hence making performance much better.\n",
"\n",
"Note - the noising and denoising terminology doesn't exactly fit here, since the \"noised dataset\" actually **reverses** the signal rather than erasing it. The reason we're describing this as denoising is more a matter of framing - we're trying to figure out which components / activations are **sufficient** to recover performance, rather than which are **necessary**. If you're ever confused, this is a useful framing to have - **noising tells you what is necessary, denoising tells you what is sufficient.**"
]
},
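{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the mechanics concrete before you do this properly in the exercises, here is a minimal sketch of a single denoising patch using TransformerLens hooks. The names `clean_tokens` and `corrupted_tokens` are hypothetical placeholders for token tensors of the same shape, and the choice of layer & position is arbitrary - the real exercises sweep over all layers and positions:\n",
"\n",
"```python\n",
"from functools import partial\n",
"\n",
"from transformer_lens import utils\n",
"\n",
"def patch_residual_component(corrupted_resid, hook, pos, clean_cache):\n",
"    # Overwrite the corrupted residual stream at one position with its clean value\n",
"    corrupted_resid[:, pos, :] = clean_cache[hook.name][:, pos, :]\n",
"    return corrupted_resid\n",
"\n",
"_, clean_cache = model.run_with_cache(clean_tokens)\n",
"\n",
"layer, pos = 9, -1  # hypothetical choice: layer 9, final sequence position\n",
"patched_logits = model.run_with_hooks(\n",
"    corrupted_tokens,\n",
"    fwd_hooks=[(\n",
"        utils.get_act_name(\"resid_pre\", layer),\n",
"        partial(patch_residual_component, pos=pos, clean_cache=clean_cache),\n",
"    )],\n",
")\n",
"```\n",
"\n",
"If the patched run's logit difference recovers towards the clean value, then the clean activation at this (layer, position) was sufficient to restore the behaviour - exactly the \"denoising tells you what is sufficient\" framing above."
]
},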
{
"cell_type": "markdown",
"metadata": {
"id": "rW51svu_Hu9J"
},
"source": [
"Question - we could instead have our corrupted sentence be \"When John and Mary went to the store, Mary gave a drink to\"
(i.e. flip all 3 occurrences of names in the sentence). Why do you think we don't do this?\n",
"\n",
"