Understanding Transformer’s Induction Heads

Natisie
May 18, 2024

Transformer-based AI models such as Large Language Models (LLMs) have demonstrated remarkable performance across a wide range of tasks. Even though these models are typically trained only to predict the next token in a sequence, they seem able to generalise to a much broader set of tasks. Understanding how transformer models achieve such capabilities is an active area of research in the AI Safety space, known as Mechanistic Interpretability.

In this article I will explain what induction heads are and the intuition behind them, describe how to visualize attention heads using libraries such as TransformerLens and CircuitsVis, and show how to identify induction heads, a special type of attention head. For an in-depth introduction to induction heads and the broader topic of Mechanistic Interpretability, I recommend the ARENA course as well as the articles In-context Learning and Induction Heads and A Mathematical Framework for Transformer Circuits.

What Are Induction Heads?

Induction heads are a special type of attention head, formed only in transformer models with more than a single layer, and they are critical for efficient in-context learning. While attention heads in general allow transformer models to focus on different parts of the input, induction heads (as the name alludes to) attend to tokens that would be predicted via inductive reasoning, i.e. based on the context rather than the training data per se. For example, if the prompt contains '... Harry Potter ... Harry', an induction head attending from the second 'Harry' back to 'Potter' boosts the prediction that 'Potter' comes next. Induction heads therefore play a crucial role in pattern recognition and in making predictions based on learned sequences, supporting the transformer's ability to understand context from previous data and to predict what might come next in a sequence.

Induction heads particularly shine in their ability to capture and utilize patterns in sequential data. For example, in language processing, they help the model predict the next word or phrase by taking into account the context provided by previous words or phrases.

How Do Induction Heads Work?

Induction heads operate within the transformer's self-attention mechanism, which assigns weights to different segments of the input according to their relevance to the task at hand. This allows the model to weigh the importance of different words in a sentence regardless of their position. For example, when trying to predict the next word in a sentence, the model uses induction heads to give more attention to relevant words that appeared earlier in the text.

This is achieved through a combination of different weight matrices that work together to dynamically learn from the input data; a minimal sketch of a single attention head built from these matrices follows the list below.

  1. Query Weights (WQ): These weights are used to determine what the model wants to know by forming queries based on the input features, essentially asking what parts of the input data are relevant.
  2. Key Weights (WK): These weights are applied to the same input and help the model identify which data to look at by creating a set of keys that can be matched against the queries.
  3. Value Weights (WV): Corresponding to each key, value weights provide the actual content (data) that will be used if a query matches a key.
  4. Output Weights (WO): After the matching process, the selected values are combined and transformed by these weights to produce the final output used for the prediction.
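
To make this concrete, here is a minimal, self-contained sketch of a single causal attention head written directly in terms of WQ, WK, WV and WO. It is an illustrative toy with random weights, not the implementation used by any particular library.

import torch as t

def attention_head(x, W_Q, W_K, W_V, W_O):
    # x: [seq, d_model] residual stream vectors for one sequence
    q = x @ W_Q                                        # queries: what each position is looking for
    k = x @ W_K                                        # keys: what each position offers for matching
    v = x @ W_V                                        # values: the content that gets moved on a match
    scores = (q @ k.T) / k.shape[-1] ** 0.5            # query-key similarity, scaled by sqrt(d_head)
    mask = t.triu(t.ones_like(scores, dtype=t.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))   # causal mask: no attending to future tokens
    pattern = scores.softmax(dim=-1)                   # attention pattern, each row sums to 1
    return pattern @ v @ W_O                           # weighted values projected back to d_model

# Toy usage with random weights
seq_len, d_model, d_head = 5, 16, 4
x = t.randn(seq_len, d_model)
W_Q, W_K, W_V = (t.randn(d_model, d_head) for _ in range(3))
W_O = t.randn(d_head, d_model)
out = attention_head(x, W_Q, W_K, W_V, W_O)            # shape: [seq_len, d_model]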

Through these mechanisms, induction heads enable transformers to adaptively focus and refocus on different parts of the input data, making them incredibly effective for a range of complex sequential tasks. This dynamic attention to the relevant parts of the data allows transformers to perform well even with long input sequences, maintaining context and coherence in tasks like text generation, translation, and summarization.

Visualizing Attention Heads

Let’s now visualize attention heads with the help of two Python libraries recently developed for mechanistic interpretability: CircuitsVis and TransformerLens. Code snippets used here have been adapted from the ARENA course.

To gain an initial intuition on how induction heads form, while ignoring the complexity of other components of the transformer architecture (such as MLP layers and layer normalization), a simple transformer model is used: a pre-trained two-layer attention-only model (attn_only_2L_half) available on the Hugging Face Hub. This can be loaded as a HookedTransformer with TransformerLens.

We start by downloading the two-layer attention-only model from the Hugging Face Hub as shown below.

from huggingface_hub import hf_hub_download

REPO_ID = "callummcdougall/attn_only_2L_half"
FILENAME = "attn_only_2L_half.pth"

weights_path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)

We now load the pre-trained weights into a HookedTransformer with the following configuration.

import torch as t
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

device = t.device("cuda" if t.cuda.is_available() else "cpu")

cfg = HookedTransformerConfig(
    d_model=768,
    d_head=64,
    n_heads=12,
    n_layers=2,
    n_ctx=2048,
    d_vocab=50278,
    attention_dir="causal",
    attn_only=True,  # defaults to False
    tokenizer_name="EleutherAI/gpt-neox-20b",
    seed=398,
    use_attn_result=True,
    normalization_type=None,  # defaults to "LN", i.e. layernorm with weights & biases
    positional_embedding_type="shortformer"
)

model = HookedTransformer(cfg)
pretrained_weights = t.load(weights_path, map_location=device)
model.load_state_dict(pretrained_weights)

Once the model is loaded, TransformerLens allows us to run it with either raw text or tokens as input. For this example we use raw text and run the two-layer attention-only model with the run_with_cache() function. This function returns two outputs: the model's output logits (unnormalised predictions) and the activations of every layer (the cache). Note that remove_batch_dim is set to True because we are running a single prompt, so the batch dimension of 1 can be dropped from the cached activations.

After running the model with cache, we convert the raw input text to tokens with the TransformerLens function to_str_tokens(), then extract the attention patterns for each layer from the cache. Both the tokens and the attention patterns are required to generate the attention head visualization with CircuitsVis, as per the code snippet below.

text_english = "We think that powerful, significantly superhuman machine intelligence is more likely than not to be created this century. If current machine learning techniques were scaled up to this level, we think they would by default produce systems that are deceptive or manipulative, and that no solid plans are known for how to avoid this."

logits, cache = model.run_with_cache(text_english, remove_batch_dim=True)

str_tokens = model.to_str_tokens(text_english)

display(cv.attention.from_cache(
    tokens=str_tokens,
    cache=cache,
    layers=[0, 1],
    head_notation='LH'))

For instance, when focusing on the token 'intelligence' and inspecting head L0H7, one can see that this particular head attends to the previous token, 'machine'. From a visual inspection of the interactive visualization of the attention patterns, one can distinguish at least three different types of heads (a quick numerical check of the first type is sketched after the list):

  • Previous token heads, which attend primarily to the previous token, e.g. L0H7
  • Current token heads, which attend mostly to the current token, e.g. L1H6
  • First token heads, which attend primarily to the first token, e.g. L0H3, L1H4 and L1H10
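
As a sanity check on the first type, one can also read the pattern straight out of the cache instead of relying on the visualization. The snippet below is a minimal sketch that assumes the TransformerLens cache layout, where cache["pattern", layer] has shape [n_heads, seq, seq] once the batch dimension is removed; the average of the sub-diagonal indicates how much head L0H7 attends to the previous token.

pattern_l0 = cache["pattern", 0]                      # layer 0 attention patterns: [n_heads, seq, seq]
prev_token_score = pattern_l0[7].diagonal(-1).mean()  # average attention paid to the previous token by head 7
print(f"L0H7 previous-token score: {prev_token_score:.3f}")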

Are Attention Patterns Language Agnostic?

Before we look into how to identify induction heads, I was curious whether the observations above are language agnostic. I translated the original raw text into Spanish and ran the model as described before. The results are striking: the same attention patterns persist!

Below, one can see that after locking the interactive visualization on the token 'significantly' in English and the corresponding token 'signific' in Spanish, and locking head L0H7 for both runs, the head attends to the previous token in both cases, regardless of the position of the token in the sequence.
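
The Spanish run mirrors the English one exactly; a translation along the following lines can be used (this is a rough translation for illustration, the exact wording matters less than it paralleling the English prompt):

text_spanish = (
    "Creemos que es más probable que no que se cree una inteligencia de máquina poderosa y "
    "significativamente sobrehumana este siglo. Si las técnicas actuales de aprendizaje automático "
    "se escalaran a ese nivel, creemos que por defecto producirían sistemas engañosos o "
    "manipuladores, y que no se conocen planes sólidos para evitar esto."
)

logits_es, cache_es = model.run_with_cache(text_spanish, remove_batch_dim=True)
str_tokens_es = model.to_str_tokens(text_spanish)

display(cv.attention.from_cache(
    tokens=str_tokens_es,
    cache=cache_es,
    layers=[0, 1],
    head_notation='LH'))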

How to Identify Induction Heads?

Now that we know how to visualize and inspect attention heads, let's have a look at a toy example that helps us identify induction heads. As mentioned before, induction heads are a special case of attention heads which seem to be critical for transformer models' in-context learning capabilities. According to the authors of the paper In-context Learning and Induction Heads, induction heads display two main characteristics:

  • Prefix matching: The head attends back to previous tokens that were followed by the current and/or recent tokens. That is, it attends to the token which induction would suggest comes next.
  • Copying: The head’s output increases the logit corresponding to the attended-to token.

With this in mind, we will attempt to identify these characteristics in the attention heads of our toy example below. It has been observed that models with induction heads are capable of predicting the repeated second half of a sequence of random tokens. This is surprising, given that this sort of data lies far outside the training distribution.

Let's generate a sequence of random tokens whose second half repeats the first half.

from torch import Tensor
from jaxtyping import Int

def generate_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> Int[Tensor, "batch full_seq_len"]:
    '''
    Generates a sequence of repeated random tokens.

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
    '''
    # Start every sequence with the beginning-of-sequence token
    prefix = (t.ones(batch, 1) * model.tokenizer.bos_token_id).long()
    # Sample random tokens once and concatenate them twice, so the second half repeats the first
    rep_tokens_half = t.randint(0, model.cfg.d_vocab, (batch, seq_len), dtype=t.int64)
    rep_tokens = t.cat([prefix, rep_tokens_half, rep_tokens_half], dim=-1).to(device)
    return rep_tokens

As in our previous example, we can run the model with cache passing in the generated random tokens.

from typing import Tuple

def run_and_cache_model_repeated_tokens(
    model: HookedTransformer, seq_len: int, batch: int = 1
) -> Tuple[t.Tensor, t.Tensor, ActivationCache]:
    '''
    Generates a sequence of repeated random tokens, runs the model on it, and returns
    the tokens, logits and activation cache.

    Uses the `generate_repeated_tokens` function above.

    Outputs are:
        rep_tokens: [batch, 1+2*seq_len]
        rep_logits: [batch, 1+2*seq_len, d_vocab]
        rep_cache: the cache of the model run on rep_tokens
    '''
    rep_tokens = generate_repeated_tokens(model, seq_len, batch)
    rep_logits, rep_cache = model.run_with_cache(rep_tokens)
    return rep_tokens, rep_logits, rep_cache


seq_len = 50
batch = 1
(rep_tokens, rep_logits, rep_cache) = run_and_cache_model_repeated_tokens(model, seq_len, batch)
rep_cache.remove_batch_dim()
rep_str = model.to_str_tokens(rep_tokens)
model.reset_hooks()
log_probs = get_log_probs(rep_logits, rep_tokens).squeeze()

print(f"Performance on the first half: {log_probs[:seq_len].mean():.3f}")
print(f"Performance on the second half: {log_probs[seq_len:].mean():.3f}")

plot_loss_difference(log_probs, rep_str, seq_len)
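
The helpers get_log_probs and plot_loss_difference come from the ARENA notebook and are not defined in this post. For completeness, a minimal sketch of what get_log_probs computes (the log probability the model assigned to each correct next token) could look like this:

def get_log_probs(logits: t.Tensor, tokens: t.Tensor) -> t.Tensor:
    # logits: [batch, seq, d_vocab], tokens: [batch, seq]
    log_probs = logits.log_softmax(dim=-1)
    # For each position, pick out the log prob assigned to the actual next token
    return log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)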

One can see that the performance of the model predicting the second part of the sequence improves drastically.

Let's now visualize the attention head patterns for this run of our two-layer attention-only model using CircuitsVis, as in the previous example.

display(cv.attention.from_cache(tokens=rep_str, cache=rep_cache, layers=[0, 1], head_notation='LH'))

Compared to our previous example, there is now a new type of pattern emerging, only in the second layer: L1H4, L1H10 and, to a lesser extent, L1H6. When focusing on the second occurrence of the 'Adding' token, heads L1H4 and L1H10 both attend to the token that followed its first occurrence, 'ometric', which reflects the expected prefix-matching characteristic of an induction head.

So far we have relied on visual inspection to identify induction heads. However, one can implement a suitable scoring function instead. An example is provided in the section 'Finding induction heads' as part of the introduction to mechanistic interpretability in the ARENA course.
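
As a rough illustration of what such a score could look like (a sketch in the spirit of the ARENA exercise, not its exact implementation): in a repeated sequence, an induction head should attend from each token in the second half back to the token just after that token's first occurrence, i.e. the position seq_len - 1 steps earlier. Averaging the attention pattern along that diagonal gives a per-head induction score.

def induction_scores(cache: ActivationCache, layer: int, seq_len: int) -> t.Tensor:
    # pattern: [n_heads, seq, seq] (batch dimension already removed)
    pattern = cache["pattern", layer]
    # Attention from each destination token to the token seq_len - 1 positions earlier
    induction_stripe = pattern.diagonal(offset=-(seq_len - 1), dim1=-2, dim2=-1)
    return induction_stripe.mean(dim=-1)   # one score per head

for layer in range(model.cfg.n_layers):
    print(f"Layer {layer} induction scores: {induction_scores(rep_cache, layer, seq_len)}")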

The In-context Learning and Induction Heads paper provides an in-depth account of the intuition behind induction head formation. According to the authors, induction heads tend to form 'abruptly' early in training, within a narrow window of roughly 2.5 to 5 billion tokens. This coincides with the phase change during which in-context learning performance shows a sudden increase.

The figure below shows how in-context learning performance increases for induction heads, whereas other attention heads do not exhibit the same radical step change.

Source: In-context Learning and Induction Heads

Conclusion

I hope that by now you have gained some general intuition about what induction heads are, why they are key to transformer models' capability for in-context learning, and how they differ from the more general attention heads.

Resources

  • ARENA course, introduction to mechanistic interpretability
  • In-context Learning and Induction Heads
  • A Mathematical Framework for Transformer Circuits
  • TransformerLens
  • CircuitsVis