Brilliant Evolution of Transformer Blocks — A Mathematical Deep Dive

Freedom Preetham
Published in Autonomous Agents
16 min read · Nov 29, 2023

Let me start by saying that this is one of the best evolutions I have come across since the inception of Transformers!

Large language models (LLMs) can expand their capabilities through various scaling strategies. The more straightforward approach involves amplifying the computational resources — this is a matter of applied AI engineering and is generally more accessible. However, a more nuanced and arguably more impactful method involves refining the underlying mathematical framework. This latter approach represents the cutting edge of AI research and is an endeavor that few can adeptly navigate.

In this blog, I will deep dive into a groundbreaking innovation in this very domain. Bobby He and Thomas Hofmann from the Department of Computer Science at ETH Zurich have introduced what may be considered the next evolutionary step in transformer technology in their paper “Simplifying Transformer Blocks”.

Paper Summary

In designing deep Transformers, a common approach is to use complex building blocks consisting of intertwined attention and MLP sub-blocks, skip connections, and normalization layers. However, this complexity can make these architectures fragile, where even small changes can significantly impact training speed or make models untrainable.

This study explores simplifications to the standard transformer block based on signal propagation theory and empirical findings. We show that many components, such as skip connections, projection or value parameters, sequential sub-blocks, and normalization layers, can be removed without sacrificing training speed.

Our experiments on autoregressive decoder-only and BERT encoder-only models demonstrate that our simplified transformers achieve comparable training speed and performance to standard transformers while being 15% faster in training throughput and using 15% fewer parameters.

The Machine the Size of the Universe.

First, let’s look at scaling transformers as they are. As mentioned, one of the relatively easy ways to scale an LLM is to keep increasing the parameter count and the training corpus. The gains from this are predictable and require no new research, as described by the LLM scaling laws. There is an entire paper written on this, and I will deep dive into it in a separate blog.

Theoretically, transformers can continue to scale with more parameters, larger datasets, and increased computational resources. However, in practice, there are several constraints that limit the extent to which transformers can be scaled.

Here is My Take on LLM Scaling

To deepen our understanding of the scaling limits of transformers, we must consider a mathematical framework that encapsulates the multifaceted aspects of LLM performance. Such a framework must account for the diminishing returns of adding parameters as the data scale increases, and for the complex interactions between model architecture efficiency and computational constraints.

Let’s define P(n,d,θ) as the potential of an LLM, where n is the number of parameters, d is the size of the training data, and θ encapsulates the efficiency of the model’s architecture. The potential, P, can be described by the recursive function:

where,

  • E represents the efficiency of parameter utilization, a non-linear function that possibly diminishes as the number of parameters increases.
  • δ_n and δ_d signify decrements in the number of parameters and dataset size, respectively, allowing us to examine the marginal utility of additional resources.
  • G embodies the computational constraints, a function that increases superlinearly with n and d, representing the growing computational cost associated with larger models and datasets.
  • D reflects the complexity or difficulty of the dataset, which also grows as the dataset size increases and introduces more challenging learning problems.

The intricate interplay between E, G, and D captures the empirical truths of transformer scaling:

  • Efficiency of Parameter Utilization (E): As the model scales, each additional parameter contributes less to model performance due to the redundant encoding of information, represented by a hyperbolic or other sub-linear function of n.
  • Computational Constraints (G): The computational cost grows faster than linearly when model size and dataset size are scaled together (training compute is roughly proportional to their product, n·d), and self-attention adds a cost that grows quadratically with sequence length.
  • Dataset Complexity (D): As more data is added, the incremental data complexity increases, requiring more nuanced model capacity to capture this complexity efficiently.
  • Diminishing Returns: The logarithmic term ensures that the growth rate of P is sublinear, aligning with empirical observations of diminished gains as model size increases.

In practice, as n and d increase, P approaches an asymptote due to these factors, which collectively impose practical limits on the effective scaling of LLMs. This saturation point is further compounded by environmental and economic considerations, which are not explicitly included in the model but are critical in real-world applications.

Hence, while it’s mathematically conceivable for the function P to grow indefinitely as n and d approach infinity, real-world constraints create a bounded landscape for LLM scaling. The complex, non-linear growth function P emphasizes the empirical reality that there are indeed practical limits to the scaling of transformer architectures.
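The potential function P above is conceptual, but its diminishing returns can be made concrete with a toy model. The sketch below assumes a power-law loss L(n, d) = E + A/n^α + B/d^β (the parametric form used in the Chinchilla scaling-law study) with hypothetical constants, plus the common rule of thumb of about 6·n·d training FLOPs. It is an illustration of the shape of the curve, not the author's P.

```python
# Toy illustration of diminishing returns; all constants are hypothetical.
E_FLOOR = 1.7            # irreducible loss that no amount of scale removes
A, ALPHA = 400.0, 0.34   # parameter-count term
B, BETA = 410.0, 0.28    # data term


def loss(n_params: float, n_tokens: float) -> float:
    """Power-law loss of the form E + A / n^alpha + B / d^beta."""
    return E_FLOOR + A / n_params**ALPHA + B / n_tokens**BETA


def train_flops(n_params: float, n_tokens: float) -> float:
    """Rule of thumb: about 6 FLOPs per parameter per training token."""
    return 6.0 * n_params * n_tokens


def scaling_sweep(n0: float = 1e8, d0: float = 2e9, steps: int = 4):
    """Scale model and data together by 10x per step; record loss gain and compute."""
    rows = []
    n, d = n0, d0
    prev = loss(n, d)
    for _ in range(steps):
        n, d = n * 10, d * 10
        cur = loss(n, d)
        rows.append((n, d, cur, prev - cur, train_flops(n, d)))
        prev = cur
    return rows


if __name__ == "__main__":
    for n, d, cur, gain, flops in scaling_sweep():
        print(f"n={n:.0e} d={d:.0e} loss={cur:.3f} gain={gain:.3f} flops={flops:.1e}")
```

Each step multiplies compute by 100 while the loss improvement shrinks, and the loss never drops below E_FLOOR: a small-scale picture of the asymptote described above.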

So what do we do? This is where AI research comes to the rescue, letting us keep progressing through mathematical efficiencies rather than raw scale.

Current Attention

The current attention mechanism within a transformer model can be dissected into a multi-layered process, deeply rooted in linear algebra and probability theory. The attention function is a mapping of a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

For a comprehensive deep dive into how transformers work, check out “The Sorcery behind GPT — Comprehensive Deconstruction of LLMs!”


The standard attention mechanism, as introduced in the seminal paper “Attention is All You Need”, is given by the equation:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$

This equation comprises several sophisticated operations:

  • Dot Product QK^T: Each query is matched with all keys through the dot product, creating a matrix of scores representing the degree to which each query aligns with each key.
  • Scaling Factor 1/√d_k: The scores are then scaled down by the inverse square root of the key dimension d_k. Without this, dot products grow with d_k and push the softmax into saturated regions where its gradients become vanishingly small.
  • Softmax Application: The softmax function is applied to the scaled scores, which converts them into a set of weights that sum to one — effectively a probability distribution.
  • Weighted Sum with V: The attention weights are then used to take a weighted sum of the value vectors, which is the output of the attention mechanism.
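The four operations above fit in a few lines of NumPy. This is a minimal single-head sketch, assuming no masking and no batching; the function names are illustrative.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtracting the max leaves the result unchanged but avoids overflow in exp.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """Single-head attention: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)     # (T_q, T_k): how well each query matches each key
    weights = softmax(scores, axis=-1)  # each row is a probability distribution over keys
    return weights @ V, weights         # output rows are weighted sums of value rows


rng = np.random.default_rng(0)
T, d_k, d_v = 4, 8, 6
Q = rng.standard_normal((T, d_k))
K = rng.standard_normal((T, d_k))
V = rng.standard_normal((T, d_v))
out, weights = scaled_dot_product_attention(Q, K, V)
print(out.shape, weights.sum(axis=-1))
```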

The current transformer architecture sub-block looks like this:

1. Removing Attention Sub-Block Skip Connection

To deepen our mathematical understanding of the modifications proposed in the paper, let’s delve into the intricacies of the attention mechanism and the modifications introduced.

Rank Collapse Issue in Transformers:

In transformers, the rank of the attention matrix is crucial because it reflects the dimensionality of the signal that flows through the network. Rank collapse occurs when this rank shrinks, limiting the signal and restricting the model’s capacity to learn complex patterns. Mathematically, if the rank of the self-attention matrix A is reduced significantly, the token representations it produces are confined to a lower-dimensional subspace. In the extreme case of rank one, every token receives the same representation and the model can no longer distinguish between positions, losing its ability to capture diverse relationships in the data.

You can read about rank collapse in greater detail here: “Deep Dive into Rank Collapse in LLMs”
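A toy experiment makes rank collapse visible. The sketch below stacks random attention layers with identity values and no MLP or normalization (all assumptions for illustration), and tracks the "spread": the distance of the token matrix from the rank-one matrix where every token equals the mean token. Without a skip connection the spread shrinks rapidly with depth; with a skip connection it survives.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def token_spread(X: np.ndarray) -> float:
    """Frobenius distance from the rank-1 matrix in which every token equals the mean token."""
    return float(np.linalg.norm(X - X.mean(axis=0, keepdims=True)))


def attention_only_stack(X: np.ndarray, depth: int, skip: bool, seed: int = 1):
    """Apply `depth` random attention layers (identity values, no MLP, no norm)."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    spreads = [token_spread(X)]
    for _ in range(depth):
        Wq, Wk = rng.standard_normal((2, d, d)) / np.sqrt(d)
        A = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(d))
        X = X + A @ X if skip else A @ X
        spreads.append(token_spread(X))
    return spreads


X0 = np.random.default_rng(0).standard_normal((16, 32))
no_skip = attention_only_stack(X0, depth=8, skip=False)
with_skip = attention_only_stack(X0, depth=8, skip=True)
print([round(s, 3) for s in no_skip])
```

Because every row of a softmax attention matrix is an average, repeatedly applying A without a skip pulls tokens toward a common vector; the skip path reinjects each token's own representation.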

Value-SkipInit Modification:

To counter rank collapse, the Value-SkipInit modification is introduced. The original self-attention mechanism in a transformer can be rewritten as follows:

$$\mathrm{Attn}(X) = A(X)\,X W^V, \qquad A(X) = \mathrm{softmax}\!\left(\frac{X W^Q (X W^K)^\top}{\sqrt{d_k}}\right)$$

where W^Q, W^K, and W^V are the weight matrices for queries, keys, and values, respectively, and d_k is the dimension of the keys.

The Value-SkipInit modification alters this to:

$$A'(X) = \alpha I_T + \beta A(X)$$

Here, A′(X) is the modified self-attention matrix, α and β are trainable scalars, and I_T is a T×T identity matrix, where T is the sequence length. Even if A(X) drifts toward a lower rank, the αI_T term preserves each token’s own representation, counteracting the effect of rank collapse.

Shaped Attention Extension:

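Shaped Attention extends this concept by introducing a centering matrix C, leading to:

$$A'(X) = \alpha I_T + \beta A(X) - \gamma C$$

The centering matrix C is fixed to the values that A(X) takes when every query-key dot product is zero. Because C equals A(X) at that point, choosing β = γ at initialization (for example with zero-initialized query weights, which make all query-key logits zero) cancels the last two terms and leaves A′(X) = αI_T, so the attention mechanism starts out as an identity mapping. This stabilizes the early training stages, when A(X) may not yet be expressive enough.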

Mathematical Implications:

  • Stabilization of Signal Flow: By introducing αI_T, the modification ensures that the flow of information across layers retains a minimum level of diversity, crucial for learning complex patterns.
  • Control over Attention Dynamics: The parameters α, β, and γ provide a mechanism to control the learning dynamics of the attention mechanism, allowing for a balance between identity preservation and learning from data.
  • Initial Bias Towards Identity: The inclusion of C biases the attention mechanism towards an identity mapping initially, providing stability in early training phases and gradually allowing the model to learn more complex dependencies as training progresses.
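Both modifications are simple matrix expressions, sketched below in NumPy. The causal mask, α = β = γ = 1, and zero-initialized query weights are assumptions chosen for illustration; with them, shaped attention is exactly the identity at initialization and moves away from it once the query weights are nonzero, while its rows still sum to α + β − γ = 1.

```python
import numpy as np


def causal_softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax over positions j <= i (decoder-style causal mask)."""
    T = logits.shape[-1]
    mask = np.tril(np.ones((T, T), dtype=bool))
    z = np.where(mask, logits, -np.inf)
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attention_matrix(X, Wq, Wk):
    """A(X) = softmax(X W^Q (X W^K)^T / sqrt(d_k)) with a causal mask."""
    return causal_softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(Wk.shape[1]))


def value_skipinit(A, alpha, beta):
    """A'(X) = alpha * I_T + beta * A(X)."""
    return alpha * np.eye(A.shape[0]) + beta * A


def shaped_attention(A, C, alpha, beta, gamma):
    """A'(X) = alpha * I_T + beta * A(X) - gamma * C."""
    return alpha * np.eye(A.shape[0]) + beta * A - gamma * C


rng = np.random.default_rng(0)
T, d = 5, 8
X = rng.standard_normal((T, d))
Wk = rng.standard_normal((d, d)) / np.sqrt(d)
C = causal_softmax(np.zeros((T, T)))  # A(X) when every query-key logit is zero

# Assumed initialization: zero query weights and beta == gamma, so A(X) == C and the terms cancel.
A_init = attention_matrix(X, np.zeros((d, d)), Wk)
S_init = shaped_attention(A_init, C, alpha=1.0, beta=1.0, gamma=1.0)

# Stand-in for a trained model: random query weights move beta*A - gamma*C away from zero.
Wq = rng.standard_normal((d, d)) / np.sqrt(d)
S_later = shaped_attention(attention_matrix(X, Wq, Wk), C, alpha=1.0, beta=1.0, gamma=1.0)
print(np.round(S_init, 3))
```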

2. Recovering Training Speed Without Attention Skips:

Let’s delve into the role of identity attention matrices in transformers and their impact on training dynamics. This approach, which initializes the attention mechanism with identity matrices, can be conceptualized through a series of equations.

Identity Matrix Initialization:

When initializing the attention mechanism with identity matrices, the attention computation at the initial stage simplifies to:

where,

  • A_init(X) denotes the self-attention output at initialization.
  • I represents the identity matrix.
  • X is the input matrix.
  • softmax is the softmax function applied to the scaled dot product.
  • d_k is the dimension of the keys, used for scaling the dot product.
  • The expression I X I^T simplifies to X, assuming each I is an identity matrix whose size matches the dimension of X it multiplies.
  • The term IV simplifies to V if V is compatible with the dimensions of I.

Since I X I^T simplifies to X and IV simplifies to V, the equation reduces to:

In this revised equation:

  • A_init(X) represents the self-attention output at initialization.
  • X is the input matrix.
  • softmax is the softmax function applied to the scaled input matrix.
  • d_k is the dimension of the keys, used for scaling.
  • V is the value matrix.

Linear Behavior at Initialization: Given the simplicity of the above formulation at initialization, the transformer’s behavior resembles a linear model. This can be further represented as:

where W_linear represents an effective linear transformation induced by the simplified attention mechanism and the remaining parts of the transformer block (like feed-forward networks).

Progression to Complex Representations: As training progresses, the attention matrices evolve away from the identity matrix, allowing the model to capture more complex dependencies. This evolution can be modeled as:

$$A_{\text{train}}(X, t) = \mathrm{softmax}\!\left(\frac{Q(t)\,K(t)^\top}{\sqrt{d_k}}\right) V(t)$$

where t denotes the training step and Q(t), K(t), V(t) evolve with training.

Impact on Training Dynamics: The transition from A_init to A_train can be modeled using a parameter λ(t) that interpolates between the initial linear behavior and the complex attention mechanism:

$$A(X, t) = \lambda(t)\,A_{\text{init}}(X) + \big(1 - \lambda(t)\big)\,A_{\text{train}}(X, t)$$

where,

  • A(X,t) represents the self-attention output at time t.
  • λ(t) is a time-dependent parameter that interpolates between the initial and trained states of the attention mechanism.
  • A_init(X) is the initial self-attention output.
  • A_train(X,t) is the self-attention output during training at time t.
  • The equation combines these two states, gradually transitioning from the initial to the trained state as λ(t) evolves over time.

λ(t) decreases from 1 to 0 as training progresses.

This mathematical framework reveals how initializing the transformer with identity attention matrices influences its early training behavior, leading to a predictable and stable learning phase that gradually transitions to capturing more complex patterns. This approach balances the need for stability in early training with the transformer’s capacity for developing intricate representations as it learns.
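The interpolation can be rendered as a toy in NumPy. This sketch makes two simplifying assumptions: it blends attention matrices rather than outputs, taking the identity matrix as the initial attention pattern, and it uses a hypothetical linear schedule for λ(t). It is a conceptual illustration, not how the paper trains its models.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def lam(t: int, total_steps: int) -> float:
    """Hypothetical schedule: lambda(t) falls linearly from 1 at t=0 to 0 at t=total_steps."""
    return max(0.0, 1.0 - t / total_steps)


def blended_attention(X, Wq, Wk, Wv, t, total_steps):
    """Attention matrix lambda(t) * I + (1 - lambda(t)) * softmax(Q K^T / sqrt(d_k))."""
    T = X.shape[0]
    A_learned = softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(Wk.shape[1]))
    lt = lam(t, total_steps)
    A = lt * np.eye(T) + (1.0 - lt) * A_learned  # convex mix of row-stochastic matrices
    return A @ X @ Wv, A


rng = np.random.default_rng(0)
T, d = 6, 8
X = rng.standard_normal((T, d))
Wq, Wk, Wv = rng.standard_normal((3, d, d)) / np.sqrt(d)
for t in (0, 500, 1000):
    out, A = blended_attention(X, Wq, Wk, Wv, t, total_steps=1000)
    print(t, lam(t, 1000), np.round(A.diagonal().mean(), 3))
```

At t = 0 each token attends only to itself, so the layer reduces to the linear map X W^V; as λ(t) decays, the learned attention pattern takes over.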

3. Eliminating Value and Projection Parameters:

The removal of value and projection parameters simplifies the internal structure of the transformer. Setting these matrices to the identity removes those linear layers altogether: the value and projection steps become pass-through operations, reducing the number of operations each block has to perform.

Let’s look at the mathematical formulation, which helps elucidate how setting the value and projection matrices to the identity simplifies the transformer’s internal structure.

Simplification to Identity:

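By setting the value (W^V) and projection (W^P) matrices to the identity matrix, the self-attention mechanism simplifies significantly. The new formulation becomes:

$$\text{Self-Attention}_{\text{identity}}(X) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) X$$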

  • Self-Attention_identity​(X) denotes the output of the self-attention layer when value and projection matrices are set to identity.
  • Q and K are the query and key matrices, typically computed as Q=XW^Q and K=XW^K where W^Q and W^K are the corresponding weight matrices.
  • Instead of using a value matrix V, the input matrix X itself is used in this case.
  • The softmax function is applied to the scaled dot product of Q and K, and the result is multiplied by X.
  • d_k represents the dimension of the key vectors, used for scaling the dot product in the softmax function.

Mathematical Implications of the Simplification:

  • Identity Value Map: Replacing V = XW^V with X turns the value transformation into an identity map, so the attention weights mix the input tokens directly.
  • Reduced Parameter Space: Eliminating W^V and W^P removes two d×d weight matrices per attention layer, reducing the number of trainable parameters, which makes training more efficient and can potentially speed up convergence.
  • Impact on Model Capacity: While this simplification reduces complexity, it also impacts the model’s capacity to learn complex data representations. The transformer now relies more heavily on the remaining components (like the feed-forward network) to capture complex patterns in the data.
  • Altered Data Flow: With the identity matrices, the flow of data through the network changes. The attention mechanism now directly propagates the input data X, mixed across tokens by the attention weights, altering the dynamics of information processing within the network.
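The simplification can be checked numerically: plugging identity matrices into the standard formula gives exactly softmax(QK^T/√d_k)X, and the attention layer keeps only W^Q and W^K. This NumPy sketch assumes a single head and ignores biases; the helper names are illustrative.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def attn_weights(X, Wq, Wk):
    return softmax((X @ Wq) @ (X @ Wk).T / np.sqrt(Wk.shape[1]))


def attention_standard(X, Wq, Wk, Wv, Wp):
    """softmax(Q K^T / sqrt(d_k)) X W^V, followed by the output projection W^P."""
    return attn_weights(X, Wq, Wk) @ (X @ Wv) @ Wp


def attention_identity_vp(X, Wq, Wk):
    """Value and projection fixed to identity: softmax(Q K^T / sqrt(d_k)) X."""
    return attn_weights(X, Wq, Wk) @ X


def attention_param_count(d: int, identity_vp: bool) -> int:
    """Single-head weight count (no biases): W^Q, W^K, plus W^V and W^P unless removed."""
    return (2 if identity_vp else 4) * d * d


rng = np.random.default_rng(0)
T, d = 5, 16
X = rng.standard_normal((T, d))
Wq, Wk = rng.standard_normal((2, d, d)) / np.sqrt(d)
I = np.eye(d)
print(np.allclose(attention_standard(X, Wq, Wk, I, I), attention_identity_vp(X, Wq, Wk)))
print(attention_param_count(d, False), attention_param_count(d, True))
```

Under these assumptions the attention layer's weight count halves, which is where a large part of the parameter saving comes from.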

4. Removing the MLP Sub-Block Skip Connection:

To deepen the mathematical understanding of removing the MLP Sub-block skip connection and the introduction of parallel sub-blocks in transformers, let’s dissect and elaborate on the new architecture involving Multi-Head Attention (MHA) and Multi-Layer Perceptron (MLP) blocks. This approach enhances the model’s efficiency through concurrent processing.

Standard Transformer Mechanism with Sequential Sub-Blocks:

In a conventional transformer, the output of each layer is computed sequentially, first passing through the MHA sub-block and then the MLP sub-block. Mathematically, this can be represented as:

where X_in is the input to the layer, and the functions MHA and MLP represent the operations of the Multi-Head Attention and Multi-Layer Perceptron blocks, respectively.

Introduction of Parallel Sub-Blocks:

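The modification involves parallel processing of the MHA and MLP sub-blocks. The equation for the output in this parallel setting becomes:

$$X_{\text{out}} = \alpha_{\text{comb}}\,X_{\text{in}} + \beta_{FF}\,\mathrm{MLP}\big(\mathrm{Norm}(X_{\text{in}})\big) + \beta_{SA}\,\mathrm{MHA}\big(\mathrm{Norm}(X_{\text{in}})\big)$$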

where,

  • X_out represents the output of the transformer layer.
  • α_comb scales the contribution of the input (the skip path), while β_FF and β_SA scale the contributions of the MLP and MHA blocks, respectively; all three are trainable parameters.
  • Norm represents a normalization function (like Layer Normalization) applied to the input before it is processed by the MLP and MHA blocks.
  • The MLP and MHA blocks process the same normalized input in parallel, and their outputs are summed with the scaled input.

Mathematical Implications of Parallel Processing:

  • Increased Efficiency: By processing the MHA and MLP blocks in parallel, the model can potentially reduce the computational time per layer.
  • Balanced Contribution: The parameters α_comb, β_FF, and β_SA allow for a flexible and balanced contribution of the MLP and MHA blocks to the final output, which can be optimized during training.
  • Simplified Skip Connection: The equation introduces a simplified skip connection that aggregates the contributions of both the MHA and MLP blocks. This differs from traditional transformers, where skip connections are usually associated with each sub-block.
  • Impact on Learning Dynamics: This architectural change alters the learning dynamics of the transformer, potentially enabling the model to capture different types of dependencies in the data more efficiently.
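The difference between the two layouts is easiest to see side by side. In this NumPy sketch, a single-head attention with identity values stands in for MHA, the MLP uses ReLU, and the normalization has no learnable gain or bias; all of these are simplifying assumptions.

```python
import numpy as np


def layer_norm(X, eps=1e-5):
    """Per-token normalization without learnable gain or bias (a simplification)."""
    mu = X.mean(axis=-1, keepdims=True)
    var = X.var(axis=-1, keepdims=True)
    return (X - mu) / np.sqrt(var + eps)


def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def mha(H, Wq, Wk):
    """Single-head attention with identity value/projection, standing in for MHA."""
    return softmax((H @ Wq) @ (H @ Wk).T / np.sqrt(Wk.shape[1])) @ H


def mlp(H, W1, W2):
    return np.maximum(H @ W1, 0.0) @ W2  # ReLU MLP


def sequential_block(X, p):
    """Pre-LN style: the MLP only sees the stream after attention has updated it."""
    X_mid = X + mha(layer_norm(X), p["Wq"], p["Wk"])
    return X_mid + mlp(layer_norm(X_mid), p["W1"], p["W2"])


def parallel_block(X, p, alpha_comb=1.0, beta_ff=1.0, beta_sa=1.0):
    """X_out = alpha_comb * X + beta_FF * MLP(Norm(X)) + beta_SA * MHA(Norm(X))."""
    H = layer_norm(X)  # one shared normalized input for both sub-blocks
    return alpha_comb * X + beta_ff * mlp(H, p["W1"], p["W2"]) + beta_sa * mha(H, p["Wq"], p["Wk"])


rng = np.random.default_rng(0)
T, d, h = 6, 16, 64
p = {
    "Wq": rng.standard_normal((d, d)) / np.sqrt(d),
    "Wk": rng.standard_normal((d, d)) / np.sqrt(d),
    "W1": rng.standard_normal((d, h)) * np.sqrt(2.0 / d),
    "W2": rng.standard_normal((h, d)) / np.sqrt(h),
}
X = rng.standard_normal((T, d))
print(sequential_block(X, p).shape, parallel_block(X, p).shape)
```

In the parallel version both sub-blocks depend only on the same normalized input, so they can be computed concurrently, and a single skip term replaces the two per-sub-block skips.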

5. Removing Normalization Layers

The removal of normalization layers and the adjustment of residual branches aim to maintain balance in the scale of different layers’ outputs, which can be achieved through precise initialization and training dynamics.

Standard Transformer Mechanism with Normalization Layers:

In a typical transformer, normalization layers are used to stabilize training and manage the scale of layer outputs. The output of a layer with normalization can be represented as:

where LayerNorm is the layer normalization operation, X is the input to the layer, and SubBlock(X) represents the operations within the sub-block (either MHA or MLP).

Removing Normalization and Adjusting Residual Branches:

With the removal of normalization layers, the emphasis shifts to adjusting the residual branches. This adjustment can be mathematically modeled as:

$$X_{\text{out}} = X + \alpha_{\text{res}} \cdot \mathrm{SubBlock}(X)$$

where α_res is a trainable scaling factor applied to the sub-block’s output to manage the output scale.

Balancing Output Scale Without Normalization:

To balance the scale of different layers’ outputs without normalization, the following strategies can be employed:

Careful Initialization: The model parameters, especially those in the sub-blocks, can be initialized in a way that prevents large fluctuations in the output scale. This can be represented as:

where W_sub are the weights of the sub-block, and InitializeWeights() is a function that initializes weights in a manner that ensures controlled output scaling.

Training Dynamics Adjustment: During training, dynamic adjustments can be made to the scaling factor α_res based on the feedback from the learning process. This adjustment can be formulated as:

where α_res(t) is the scaling factor at training step t, and AdjustFactor is a function that adjusts α_res based on the feedback received during training.

Mathematical Implications of Removing Normalization Layers:

  • The removal of normalization layers necessitates a fine-tuned control over the scale of the outputs across different layers.
  • The adjustable scaling factor α_res becomes critical in ensuring that the outputs of different layers remain balanced without the stabilizing effect of normalization.
  • Careful initialization and dynamic adjustment during training help in managing the potential scale discrepancies that might arise due to the absence of normalization.
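A small experiment shows why α_res matters once normalization is gone. The sketch assumes a norm-free residual stack X_{l+1} = X_l + α_res·f(X_l), where f is a random ReLU MLP with He-style initialization and α_res is held fixed (the AdjustFactor dynamics above are not modeled). With α_res = 1 the activation scale grows roughly geometrically with depth; with a small α_res it stays close to the input scale.

```python
import numpy as np


def rms(X: np.ndarray) -> float:
    return float(np.sqrt(np.mean(X**2)))


def norm_free_stack(X, depth, alpha_res, seed=0):
    """X_{l+1} = X_l + alpha_res * f(X_l), with f a random ReLU MLP and no LayerNorm."""
    rng = np.random.default_rng(seed)
    d = X.shape[1]
    scales = [rms(X)]
    for _ in range(depth):
        # He-style init keeps f's output scale close to its input scale.
        W1 = rng.standard_normal((d, d)) * np.sqrt(2.0 / d)
        W2 = rng.standard_normal((d, d)) / np.sqrt(d)
        X = X + alpha_res * (np.maximum(X @ W1, 0.0) @ W2)
        scales.append(rms(X))
    return scales


X0 = np.random.default_rng(42).standard_normal((16, 64))
for a in (1.0, 0.1):
    s = norm_free_stack(X0, depth=24, alpha_res=a)
    print(f"alpha_res={a}: rms at layer 0 = {s[0]:.2f}, at layer 24 = {s[-1]:.2f}")
```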

6. Final Architecture

The figure below visualizes the final simplified transformer architecture proposed in the paper:

Final Simplified Architecture

Isn’t this simply BRILLIANT?
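To see how the pieces fit together, here is a sketch that stitches sections 1 to 5 into one block: shaped attention with identity values and projection, computed in parallel with an MLP, with no explicit skip connection and no normalization. It is an assumption-laden illustration, not the paper's reference implementation; the paper also reports variants that keep normalization, and the zero-initialized queries and α = β = γ = 1 here are illustrative choices.

```python
import numpy as np


def causal_softmax(logits):
    T = logits.shape[-1]
    mask = np.tril(np.ones((T, T), dtype=bool))
    z = np.where(mask, logits, -np.inf)
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def simplified_block(X, p, alpha=1.0, beta=1.0, gamma=1.0, beta_ff=1.0):
    """Shaped attention (identity values/projection) in parallel with an MLP.

    There is no explicit skip connection and no normalization: the alpha * I_T
    term inside shaped attention carries each token's input forward.
    """
    T = X.shape[0]
    A = causal_softmax((X @ p["Wq"]) @ (X @ p["Wk"]).T / np.sqrt(p["Wk"].shape[1]))
    C = causal_softmax(np.zeros((T, T)))  # A(X) when all query-key logits are zero
    shaped = alpha * np.eye(T) + beta * A - gamma * C
    mlp_out = np.maximum(X @ p["W1"], 0.0) @ p["W2"]
    return shaped @ X + beta_ff * mlp_out


rng = np.random.default_rng(0)
T, d, h = 6, 16, 64
p = {
    "Wq": np.zeros((d, d)),  # assumed zero-initialized queries, so A(X) == C at init
    "Wk": rng.standard_normal((d, d)) / np.sqrt(d),
    "W1": rng.standard_normal((d, h)) * np.sqrt(2.0 / d),
    "W2": rng.standard_normal((h, d)) / np.sqrt(h),
}
X = rng.standard_normal((T, d))
Y = simplified_block(X, p)
print(Y.shape)
```

With the MLP branch switched off, the block returns its input unchanged at initialization, which is the identity-like starting point the earlier sections aim for; the only attention weights left are W^Q and W^K.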

7. Experimental Analysis

Scaling Depth and Training Speed: The simplified transformer blocks train faster at larger depths, as shown in the figure below. They not only train faster but also make effective use of the extra capacity. In contrast, Value-SkipInit, despite its increased capacity, trains more slowly at greater depths and scales poorly compared to the simplified models.

The models improve when deeper (dashed, marked lines) vs. shallower (solid lines), unlike V-SkipInit (He et al., 2023).

BERT: The paper evaluates the simplified blocks across several datasets and architectures, including a bidirectional encoder-only BERT model for masked language modeling and fine-tuning on the GLUE benchmark. The authors adopt the “Crammed” BERT setup, designed for resource-constrained training on a single consumer GPU within a 24-hour window.

The simplified blocks, especially when combined with normalization, matched the pre-training speed of the Crammed Pre-LN baseline within the 24-hour training limit, as shown in the figure below. However, removing skip connections without modifying values and projections, as in He et al. (2023), significantly reduced training speed.

Masked language modeling loss vs runtime on a 2080Ti GPU for 24 hours.

The table below shows that the simplified methods match the performance of the Crammed BERT baseline after fine-tuning on the GLUE benchmark. The authors observe that Value-SkipInit can recover during fine-tuning, suggesting that factors beyond pre-training speed influence downstream performance. Removing normalization layers caused instabilities during fine-tuning on some downstream datasets.

The study highlights the effectiveness of the simplified model blocks, especially in resource-constrained scenarios, while acknowledging the importance of various factors in fine-tuning performance.

Training Smaller Models Longer: To align with the trend of training smaller models for longer on more data, the authors also ran experiments with the simplified blocks. When trained on 3× as many tokens (around 2B tokens), the simplified SAS and SAS-P blocks match or exceed the training speed of the Pre-LN block, as shown in the figure below.

Future Directions

The simplifications proposed in the research paper “Simplifying Transformer Blocks” reflect a profound shift in understanding the nuances of transformer architecture. They show that comparable or even improved model performance is achievable with a reduced computational footprint: about 15% fewer parameters and 15% faster training throughput.

A detailed mathematical exposition could involve eigenvalue analysis of the modified attention matrices, spectral decomposition to understand the effects of centering matrix C, and extensive ablation studies to quantify the impact of each simplification step.

In essence, this journey towards simplification not only makes transformers more accessible and efficient but also opens new pathways to understanding deep learning architectures fundamentally. It paves the way for more sustainable AI and machine learning solutions that are both powerful and practical for real-world applications.
