Transformers for software engineers (in Python)

Stephen Jonany
4 min read · Jan 10, 2024

I recently reimplemented Transformers for software engineers in Python here. Here are some things I learned from the exercise.

Transformer interface: Output is a logit per prefix

Code

The decoder-only transformer outputs not just the next-token logits for the entire input context, but a set of logits for every prefix!

E.g. if the input is “this is my sentence”, then we return 4 sets of logits, one for each of the 4 prefixes, like so:

  • this: [is: 0.2, apple: 0.1, …]
  • this is: [my: 0.2, word: 0.1, … ]
  • this is my: [sentence: 0.2, word: 0.1, … ]
  • this is my sentence: [word: 0.1, … ]

This might seem wasteful, since if you are just predicting the next word, only the last set of logits is useful. However, during training, it’s efficient to get a loss term for every prefix in a single transformer invocation.
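Here is a minimal NumPy sketch of those shapes. The random logits and the token ids are made up; the point is that row i of the output scores the token that follows prefix i, so every prefix yields a training signal at once.

```python
import numpy as np

rng = np.random.default_rng(0)
N_TOKEN, N_VOCAB = 4, 1000

# Pretend this came from transformer(["this", "is", "my", "sentence"]):
# one row of next-token logits per prefix, shape [N_TOKEN, N_VOCAB].
logits = rng.normal(size=(N_TOKEN, N_VOCAB))
token_ids = np.array([17, 42, 7, 99])  # hypothetical ids of the 4 input tokens

# Inference: only the last row matters for predicting the 5th token.
next_token_id = logits[-1].argmax()

# Training: every prefix contributes a loss term in a single invocation.
# Row i predicts token i+1, so compare logits[:-1] against token_ids[1:].
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
loss = -log_probs[np.arange(N_TOKEN - 1), token_ids[1:]].mean()
```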

Residual stream

Code

Conceptually, the residual stream is the transformer’s internal state at a specific layer. I think of it as an N_TOKEN-sized dictionary of opaque data structures (in the code, we call these States).
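Here is a minimal sketch of how I picture that data structure (the class and method names are my own, not necessarily the repo’s):

```python
from dataclasses import dataclass
import numpy as np

D_EMBED = 768  # hypothetical embedding width

@dataclass
class State:
    """One opaque per-token entry in the residual stream."""
    vec: np.ndarray  # shape [D_EMBED]

@dataclass
class ResidualStream:
    """N_TOKEN States; states[i] belongs to the prefix ending at token i."""
    states: list[State]

    def add(self, i: int, delta: np.ndarray) -> None:
        # Layers don't overwrite a State; they add their output into it.
        self.states[i].vec = self.states[i].vec + delta
```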

Changing meanings. The meaning of this N_TOKEN-sized dictionary changes after each layer:

  • First layer. After the embedding layer, the residual stream is a collection of token embeddings.
  • Attention: context-aware. After the first attention step, this is a set of context-aware embeddings, one per token. But remember that each state can only attend to itself and the states before it, not the ones after.
  • MLP: heavy-weight processing. After the first MLP, and for the subsequent layers, it’s something mysterious that we don’t have good intuition for :)
  • Final layer: prefix summary. At the very final layer, before unembedding, each State in the residual stream is a summary of a prefix of the input. That is, stream.states[2] is a summary of “I love you”, whereas stream.states[1] is a summary of “I love”. This happens because in every attention layer we ensure that states[i] never depends on the token embeddings of tokens after i, no matter which layer we’re on.
  • So, beware of the single name ResidualStream sprawling across the layers. Its contents mean something different at each layer; they might as well be called ResidualStreamLayer1, ResidualStreamLayer2, …

State = opaque data structure w/ efficient compression. I liked the blog post’s note that we can think of State, an item in the residual stream, as an opaque data structure that efficiently stores an exponential (in D_EMBED) number of almost-orthogonal vectors. So, think of a POJO with roughly e^D_EMBED scalar fields, but where the way you access them is not POJO.state[i], but POJO.dot_product(state_query_vector).
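A toy NumPy illustration of that access pattern (10,000 fields rather than e^D_EMBED, and the numbers are arbitrary, but the idea is the same):

```python
import numpy as np

D_EMBED = 768
rng = np.random.default_rng(0)

# Random unit vectors in D_EMBED dimensions are almost orthogonal to each other,
# so each one can serve as a separate "field" of the State.
fields = rng.normal(size=(10_000, D_EMBED))
fields /= np.linalg.norm(fields, axis=1, keepdims=True)

# Write two fields, with scalar values 3.0 and -1.5, by superposition.
state = 3.0 * fields[0] - 1.5 * fields[1]

# Read fields back with a dot product instead of state[i].
print(state @ fields[0])  # ~3.0 (plus a little interference)
print(state @ fields[1])  # ~-1.5
print(state @ fields[2])  # ~0.0 for a field we never wrote
```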

Attention layer

Code

Conceptually, this layer gives the States in the ResidualStream cross-state / contextual awareness (but only in the prefix direction). This inter-state mixing can also be thought of as a preparatory stage before we do heavy per-state processing with the MLP layer.
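Here is a single-head sketch of that mixing, assuming the residual stream is an [N_TOKEN, D_EMBED] array and the query/key/value weights are passed in. It glosses over multiple heads, the output projection, and layer norm; the part to notice is the causal mask, which is what keeps the mixing prefix-only.

```python
import numpy as np

def causal_attention(stream: np.ndarray, W_q: np.ndarray, W_k: np.ndarray, W_v: np.ndarray) -> np.ndarray:
    """Single-head sketch. stream: [N_TOKEN, D_EMBED]; returns per-state updates."""
    n_token = stream.shape[0]
    q, k, v = stream @ W_q, stream @ W_k, stream @ W_v
    scores = q @ k.T / np.sqrt(k.shape[-1])  # [N_TOKEN, N_TOKEN]
    # Causal mask: state i may attend to itself and earlier states, never later ones.
    mask = np.tril(np.ones((n_token, n_token), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # mixed states, to be added back into the residual stream
```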

MLP layer

Code

Conceptually, this layer, with its fancy non-linearity, is what performs the interesting processing on top of the attention-mixed state vectors.

Per-state. The MLP layer is applied to each state in the residual stream independently, unlike the attention layer, which accepts the entire residual stream. It makes sense that it doesn’t have to look at multiple states: the attention layer already did the mixing, and all that’s left is to process each mixture one state at a time.

Neuron: 2 vectors. Each neuron can be thought of as being composed of 2 vectors: a read vector and a write vector. The read vector describes how the neuron converts the state vector into a scalar. The write vector describes how the neuron converts that scalar, after the non-linearity, into a state update vector. This is a cool alternate view of the more popular matrix formulation: “the MLP layer is a multiplication of the [N_TOKEN, D_EMBED] residual stream with a [D_EMBED, D_MLP] matrix, followed by an element-wise non-linearity, followed by a re-projection back with another [D_MLP, D_EMBED] matrix.”
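A per-state NumPy sketch of that neuron view (GELU is one common choice of non-linearity here; biases are omitted, and the names are mine):

```python
import numpy as np

def gelu(x):
    # tanh approximation of the GELU non-linearity
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def mlp_update(state: np.ndarray, read_vecs: np.ndarray, write_vecs: np.ndarray) -> np.ndarray:
    """Per-state MLP as a sum over neurons.

    read_vecs:  [D_MLP, D_EMBED] -- row d turns the state into this neuron's scalar
    write_vecs: [D_MLP, D_EMBED] -- row d turns that scalar back into a state update
    """
    update = np.zeros_like(state)
    for read_vec, write_vec in zip(read_vecs, write_vecs):
        scalar = gelu(read_vec @ state)  # how strongly this neuron fires on this state
        update += scalar * write_vec     # this neuron's contribution to the state update
    return update

# Equivalent matrix form: gelu(state @ read_vecs.T) @ write_vecs
```

The loop over neurons is just the matrix multiplication unrolled; the read vectors are the rows of the first matrix and the write vectors are the rows of the second.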

Altogether now!

Let’s go through an example:

  • Input: [“this”, “is”, “a”, “sentence”].
  • Embedding. After embedding, if D_EMBED is 2, then we have 4 2-d vectors: [[0.2,0.1], [0.3,0.4], [0.1, 0.2], [0.5, 0.6]]. This is our first ResidualStream, where states[1] = [0.3, 0.4] is the token embedding for “is”
  • First attention. After the first block’s attention layer, we still maintain the same shape of 4 2-d vectors as our residual stream. However, post_attention_stream.states[1] = a mixture of “this” and “is” token embeddings.
  • First MLP. After the first block’s MLP layer, post_mlp_stream.states[1] = a summary of “this is”
  • Then subsequent blocks operate on these prefix-context-aware summaries
  • Unembedding. At the very final layer, before unembedding, final_stream.states[1] is a fully processed summary of “this is”. The prediction of what comes next is based only on this one state, as the toy script after this list shows.
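The same walkthrough as a toy script. The attention and MLP bodies are placeholders, and the unembedding matrix is random; only the shapes and the prefix-only mixing are the point.

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = ["this", "is", "a", "sentence"]  # stream[i] corresponds to tokens[i]
D_EMBED, N_VOCAB, N_BLOCKS = 2, 10, 3

# Embedding: one D_EMBED vector per token -> the first residual stream, shape [4, 2].
stream = np.array([[0.2, 0.1], [0.3, 0.4], [0.1, 0.2], [0.5, 0.6]])

def toy_attention(s: np.ndarray) -> np.ndarray:
    # Stand-in for attention: each state mixes with itself and earlier states only.
    return np.array([s[: i + 1].mean(axis=0) for i in range(len(s))])

def toy_mlp(s: np.ndarray) -> np.ndarray:
    # Stand-in for the MLP: per-state non-linear processing, shape unchanged.
    return np.tanh(s)

for _ in range(N_BLOCKS):
    stream = stream + toy_attention(stream)  # stream[1] only ever sees "this" and "is"
    stream = stream + toy_mlp(stream)        # heavy per-state processing

# Unembedding: the prediction for what follows "this is" uses only stream[1].
W_unembed = rng.normal(size=(D_EMBED, N_VOCAB))  # stand-in unembedding matrix
logits_after_this_is = stream[1] @ W_unembed
print(logits_after_this_is.shape)  # (10,)
```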

Closing thoughts

This exercise was very helpful for conceptualizing what the transformer is doing. It makes me ask questions like: “Why do we only have 96 residual blocks of heavy-weight processing? Is there prior work that uses an RNN to simulate a loop-supporting algorithm on the residual stream?” I don’t know the answers yet, but this is just one example of how the conceptual framework makes me ask questions I’m interested in. So! I would recommend this exercise to anyone who wants to understand transformers conceptually.
