Snowflake Arctic Cookbook Series: Building an Efficient Training System for Arctic

On April 24, Snowflake Arctic was released to the world with a key goal in mind: to be truly open. As part of that initiative, the Snowflake AI Research team is delivering a series of cookbooks that describe how to pretrain, fine-tune, evaluate, and serve large-scale MoE models such as Arctic. We will share our journey of training the Arctic model, along with our findings on sourcing and composing pre-training data, designing MoE architectures, co-designing models with training and inference systems in mind, and methods for fine-tuning and evaluating the models.

You can always find the full series here, in our Snowflake Arctic cookbook catalog.

The training system for Arctic was built on top of the DeepSpeed library. DeepSpeed provides a suite of essential optimizations to build upon, including ZeRO-2 and expert parallelism for efficient large-scale MoE training. In this blog, we discuss the training system in detail and share an in-depth look at the performance bottlenecks we identified and the optimizations we developed on top of DeepSpeed to overcome them.

Architecture and Parallelism Overview

Figure 1. Dense, MoE and Dense-MoE Hybrid Transformer Architecture

Arctic uses a Dense-MoE hybrid architecture (Figure 1, right) that combines a dense transformer of approximately 10 billion parameters with residual sparse MoE FFN modules. The MoE module has 128 experts, each with about 3.66 billion parameters, combined with a top-2 gating function. In total, the model has 17 billion active parameters and 480 billion total parameters. The Arctic model states (the combination of parameters, gradients, and optimizer states) require roughly 7.5 TB of memory, so training must leverage the aggregate GPU memory of hundreds of devices.

We adopt a combination of two parallelization strategies, ZeRO Data-Parallelism (DP) and Expert-Parallelism (EP), both to leverage aggregate GPU memory across hundreds of devices and to allow for efficient training when scaling to hundreds or thousands of GPUs. More specifically, we use ZeRO Stage-2 across all GPUs for the dense parameters, while for the MoE parameters we use expert parallelism to distribute the experts among E devices and ZeRO Stage-2 to scale each expert across the remaining GPUs.

For instance, if the total number of GPUs is N and the expert-parallelism degree is E, then the dense parameters are scaled across all N devices using ZeRO Stage-2, while the experts are distributed among E devices and each expert is scaled across the remaining N/E GPUs using ZeRO Stage-2. The figure below shows how we handle these three dimensions of parallelism and the communication collectives required to parallelize the workload.

Figure 2: How model parameters are distributed across N GPUs. Each device contains two sets of parameters: dense (blue) and sparse (yellow). We also show the parallel groups needed for the MoE and dense parts, and the kind of communication required for each. For data parallelism, we use different group sizes for the dense part (N) and the MoE part (N/E). Thus, there are two different all-reduce operations: one across the same experts of different EP groups (shown as the yellow arrow) and one across all devices (shown as the green cross).
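
To make this concrete, below is a minimal sketch of the two gradient all-reduce paths shown in Figure 2. It is illustrative only (not the DeepSpeed implementation) and assumes `dense_params` and `expert_params` partition the model, `dp_group` spans all N ranks, and `ep_dp_group` spans the N/E ranks that hold replicas of the same experts; ZeRO Stage-2 additionally partitions gradients and optimizer states within each group.

```python
import torch.distributed as dist

def allreduce_gradients(dense_params, expert_params, dp_group, ep_dp_group):
    # Dense parameters are replicated on all N GPUs, so their gradients are
    # averaged across the full data-parallel group.
    for p in dense_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, group=dp_group)
            p.grad /= dist.get_world_size(group=dp_group)

    # Each expert is replicated only on the N/E GPUs of its expert-data-parallel
    # group, so its gradients are averaged across that smaller group.
    for p in expert_params:
        if p.grad is not None:
            dist.all_reduce(p.grad, group=ep_dp_group)
            p.grad /= dist.get_world_size(group=ep_dp_group)
```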

Identifying Performance Bottlenecks

To identify system bottlenecks, we rely on the statistics provided by the DeepSpeed Profiler, which offer quantitative insights into the Tera-Flops efficiency of various model blocks during the training forward pass. Additionally, we consider forward-backward compute latencies as well as end-to-end training latency, which captures overall training-pipeline bottlenecks, including data processing and communication delays. With this careful breakdown of training latencies, we identified the following main bottlenecks to optimize training throughput:

Inefficient irregular and sparse operators: Each MoE block consists of several sparse operations, e.g., dispatching the input tokens to the right expert, which are originally implemented with einsum and cumsum operations. To perform these operations as regular dense computations, the naive PyTorch implementation has to work with huge, mostly-zero matrices, which results in a lot of wasteful computation. Furthermore, we identified several other operations, such as SwiGLU and Rotary Position Embedding (RoPE), that are ideal targets for kernel fusion.

Activation-Checkpointing recomputation overhead: Activation checkpointing is used to save activation memory by recomputing activations in the backward pass. The usual granularity is to treat each transformer layer as the checkpointing boundary. However, recomputing all of the activations inside a layer adds up to 33% training overhead.
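
For reference, layer-granularity checkpointing can be expressed with PyTorch's built-in utility as in the sketch below; the overhead mentioned above comes from re-running every operation inside `layer` during the backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

def forward_with_layer_checkpointing(layers, hidden_states):
    for layer in layers:
        # Only the layer inputs are kept; everything computed inside `layer`
        # is recomputed when gradients for this layer are needed.
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    return hidden_states
```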

Communication overheads: There are two forms of communication overhead: i) the all-to-all communication of tokens among experts when using expert parallelism, and ii) the gradient all-reduce associated with ZeRO Stage-2. Without any further optimizations, we observed the communication overhead to be larger than 50%.

Addressing Performance Bottlenecks

Custom Fused Kernels for Irregular and Sparse Operators

To improve the performance of irregular and sparse operators, we crafted custom kernels through operator fusion. Here, we dive into how we employed this technique to enhance the efficiency of each transformer block:

  1. RoPE Fusion: By fusing the small operations needed to rotate and compute positional embeddings for the Query and Key heads, we reduce computation cost and memory-bandwidth usage, cutting back-and-forth traffic to GPU global memory. Using this kernel, we get a 5% to 8% overall training speedup.
  2. SwiGLU Fusion: We replace chunking with direct indexing of the different matrix portions plus the activation function, further optimizing performance (see the reference sketch after this list). By utilizing this fusion, we obtain another 8% speedup in end-to-end training time.
  3. MoE-Gating Fusion: This uses two kernels: one computes expert scores along with the top-k selection, while the other scatters input tokens into the consecutive buffers assigned to each expert. Additionally, we employ a MoE-Gather kernel to replace sparse matrix multiplication, simplifying the combination of expert outputs from the top-k selection. MoE-gating fusion brings the largest benefit of the kernel fusions, making the gating operations 60% faster for top-1 and 2.2x faster for top-2 MoE configurations.
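
As a point of reference for the SwiGLU fusion, the sketch below shows the unfused PyTorch computation that the kernel replaces. In the fused version, indexing the two halves of the projection, the SiLU activation, and the elementwise multiply happen in a single kernel pass instead of separate operations with an intermediate tensor (this is an illustrative reference, not the kernel itself).

```python
import torch
import torch.nn.functional as F

def swiglu_reference(x_proj: torch.Tensor) -> torch.Tensor:
    # x_proj: [tokens, 2 * ffn_dim], the output of the combined gate/up projection.
    gate, up = x_proj.chunk(2, dim=-1)
    # Unfused path: separate kernels for the SiLU activation and the multiply,
    # each making a round trip through GPU global memory.
    return F.silu(gate) * up
```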

Let’s take a closer look at the fusion techniques applied to the MoE module. MoE-Gating involves various operations with sparse matrices to select tokens for each expert. Here’s an overview of the functions involved in computing the MoE mask and score for token dispatch:

  • Gating Logic computes the initial mask using top-k (k = 1, 2, or 4) expert selection for each token, along with a Random-Token-Selection (RTS) mask to ensure load balancing across experts. This operation involves generating a costly sparse mask and a top-K (K = expert capacity) token selection.
  • Indexing uses the mask to compute each token's position inside its expert's buffer using a sequential cumsum operation.
  • Output Mask generates a 3-dimensional one-hot matrix (#tokens x #experts x capacity) showing the capacity slot in which each token runs on each expert.

Figure 3. The operations involved in the MoE-gating module: computing the mask used to dispatch the input tokens to the experts, and the gating weights that are later used to gather the expert outputs back into the output tokens.

The gating output mask is used for scattering the input tokens and dispatching them to the experts. In PyTorch, this is performed with an einsum operation that reshapes and transposes the input and mask matrices, followed by a matrix multiplication. However, these operations incur data-reordering and data-copy overheads, along with wasteful sparse multiplications. To tackle this, we implement custom kernels that efficiently scatter input tokens to the correct experts using metadata generated in the previous steps.
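
Putting these steps together, a simplified top-1 version of this naive path looks roughly as follows (an illustrative sketch, not the exact DeepSpeed implementation, and omitting Random-Token-Selection). Note how the final einsum multiplies a mostly-zero (#tokens x #experts x capacity) mask with the token matrix.

```python
import torch
import torch.nn.functional as F

def naive_top1_dispatch(logits, tokens, capacity):
    # logits: [s, e] router outputs; tokens: [s, m] input activations.
    s, e = logits.shape
    gates = F.softmax(logits, dim=-1)                            # [s, e]
    expert_idx = torch.argmax(gates, dim=-1)                     # top-1 expert per token
    mask = F.one_hot(expert_idx, num_classes=e)                  # [s, e], mostly zeros
    locations = torch.cumsum(mask, dim=0) - 1                    # slot of each token per expert
    mask = mask * (locations < capacity)                         # drop tokens over capacity
    slot = (locations * mask).sum(dim=1)                         # [s] slot of each kept token
    gate_s = (gates * mask).sum(dim=1)                           # [s] score of the chosen expert
    slot_onehot = F.one_hot(slot, num_classes=capacity).to(gates.dtype)   # [s, c]
    # 3-D combine/dispatch mask: (#tokens x #experts x capacity), mostly zeros.
    combine = torch.einsum("se,sc->sec", gate_s.unsqueeze(1) * mask, slot_onehot)
    dispatch_mask = (combine > 0).to(tokens.dtype)
    # Wasteful dense matmul against a sparse mask to scatter tokens to experts.
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask, tokens)
    return dispatched, combine
```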

Furthermore, the MoE layer needs to place the expert outputs back into the original token ordering and multiply them by the MoE-gating scores. This involves multiple reshapes and sparse multiplications. Below, we explore kernel-fusion strategies that eliminate these overheads and restore the efficiency of sparse token computation in MoE architectures.

Figure 4. The optimized kernels used to replace the sparse and complicated operations of the MoE-gating function.

We divide the operations for the MoE layer into several main kernels:

  • Logit softmax + Top-1/2/4 selection: in this part, we create an assignment table for the input tokens showing which expert is selected, and we also save the expert score for each token, which is used after the expert's execution when reordering the expert outputs.
  • Random-Token-Selection + TopK: in this kernel, we finalize the expert assignment based on the number of tokens assigned to each expert. If an expert is assigned more tokens than its capacity, the RTS part drops some of them based on random probabilities to enforce load balancing. Experts with fewer assigned tokens than their capacity process all of their tokens.
  • Scatter tokens in a contiguous manner: this kernel moves the tokens into the slots of their selected experts, and we store the slot-mapping information in a small table of size #tokens, which is used both in the Gather kernel and in the backward pass of this module.
  • Gather the expert outputs: we place the output of each expert into the corresponding token's mapped slot and multiply it by the MoE scores computed in the first kernel.

By removing all of the data-reorganization and sparse-computation overhead in MoE gating, we significantly reduce training time, achieving an overall 3x speedup for the gating operations compared to the baseline PyTorch implementation at an expert-parallelism degree of 64.
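
In plain PyTorch terms, the scatter and gather kernels compute roughly the following index-based operations (a hypothetical helper for illustration, not the actual fused CUDA kernels); contrast this with the einsum-based dispatch shown earlier.

```python
import torch

def scatter_gather_experts(tokens, expert_idx, slot_idx, gate_score, kept, experts, capacity):
    # tokens: [s, m]; expert_idx / slot_idx / gate_score / kept: [s] per-token metadata
    # produced by the gating kernels (kept is the boolean keep/drop mask from RTS).
    s, m = tokens.shape
    e = len(experts)
    buffers = tokens.new_zeros(e, capacity, m)

    # Scatter: copy each kept token into its (expert, slot) position of a
    # contiguous per-expert buffer; no sparse matmul involved.
    flat_pos = expert_idx * capacity + slot_idx                     # [s]
    buffers.view(e * capacity, m)[flat_pos[kept]] = tokens[kept]

    # Run each expert on its contiguous buffer as a dense GeMM.
    expert_out = torch.stack([experts[i](buffers[i]) for i in range(e)])   # [e, c, m]

    # Gather: read each token's output back from its slot and scale it by the
    # gate score computed in the first gating kernel.
    out = tokens.new_zeros(s, m)
    out[kept] = expert_out.view(e * capacity, m)[flat_pos[kept]] * gate_score[kept].unsqueeze(-1)
    return out
```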

Selective Activation-Checkpointing and Quantized Activations

By combining profiling data with communication micro-benchmarking, we pinpointed several costly operations within our model architecture. Some of these operations, such as linear layers, demand high computational power, while others, such as all-to-all, incur significant communication volume. Additionally, we identified operations, such as activation functions, that have large activation memory footprints but are inexpensive to recompute.

Leveraging this understanding of compute versus memory cost, we devised a new API that allows each module to be designated either for recomputation or for having its outputs saved for reuse in the backward pass. This scheme, layered atop the existing activation checkpointing system, introduces flexibility in reusing parts of the computation graph when memory allows.

Compared to the existing selective activation recomputation approach, we provide finer-grained control over which modules are checkpointed or recomputed. In particular, we can selectively recompute or skip recomputation of MoE gating, the expert-input scatter, and the all-to-all communication, as well as all the linear layers.
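
A minimal sketch of the idea is shown below (a hypothetical wrapper; the actual API is integrated with DeepSpeed's activation checkpointing): modules that are cheap to recompute are wrapped in checkpointing, while expensive ones keep their outputs for the backward pass.

```python
import torch
from torch.utils.checkpoint import checkpoint

class SelectiveCheckpoint(torch.nn.Module):
    """Wraps a submodule and either recomputes it in backward or saves its outputs."""

    def __init__(self, module: torch.nn.Module, recompute: bool):
        super().__init__()
        self.module = module
        self.recompute = recompute

    def forward(self, *args):
        if self.recompute and torch.is_grad_enabled():
            # Cheap-to-recompute modules (e.g. activation functions, gating):
            # discard intermediate activations and rerun them in the backward pass.
            return checkpoint(self.module, *args, use_reentrant=False)
        # Expensive modules (e.g. large GEMMs, all-to-all): keep their activations.
        return self.module(*args)
```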

However, this customized and more flexible activation checkpointing increases memory pressure, even with ZeRO Stage-2 and expert parallelism already enabled. To alleviate this, we also added support for saving activations in quantized form and dequantizing them as they are needed in the backward pass. This frees more memory to save more activations during the forward pass, especially for costly operations like GEMMs, where dequantization is significantly faster than recomputation.
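
The sketch below illustrates one way to express this in PyTorch using saved-tensor hooks, with simple symmetric int8 quantization for illustration; the actual system uses its own quantization scheme and applies it selectively to the designated modules. The toy `model` and `batch` are placeholders.

```python
import torch

def pack_quantized(t: torch.Tensor):
    # Quantize floating-point tensors saved for backward to int8 (symmetric, per-tensor).
    if not t.is_floating_point():
        return t
    scale = t.abs().amax().clamp(min=1e-8) / 127.0
    return torch.round(t / scale).to(torch.int8), scale, t.dtype

def unpack_quantized(packed):
    if isinstance(packed, torch.Tensor):
        return packed
    q, scale, dtype = packed
    # Dequantize lazily, only when the backward pass actually needs the tensor.
    return (q.to(torch.float32) * scale).to(dtype)

# Toy usage: tensors saved for backward inside this context are stored in int8.
model = torch.nn.Sequential(torch.nn.Linear(16, 64), torch.nn.GELU(), torch.nn.Linear(64, 16))
batch = torch.randn(4, 16)
with torch.autograd.graph.saved_tensors_hooks(pack_quantized, unpack_quantized):
    loss = model(batch).sum()
loss.backward()
```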

By integrating these strategies, we strike a balance between computational efficiency and memory usage, optimizing our training process for enhanced performance.

Reducing Communication Overhead

Communication-Aware Parallelization Topology

Both expert and data parallelism introduce communication overhead when distributing work across multiple resources. Expert-parallelism necessitates all-to-all operations to scatter tokens among parallelized experts across E devices, while data-parallelism requires gradient averaging through all-reduce operations across the data-parallel dimension before parameter updates.

In deciding how to map expert-parallel and data-parallel groups to the GPUs, we have two choices: a) E+D, where the experts are placed close to each other and each expert is replicated with an E-device stride to form its data-parallel group; and b) D+E, where we first replicate each expert N/E times (N is the total number of devices) on nearby GPUs, so the expert-parallel group spans GPUs that are farther apart.

Figure 5. The two different layouts for combining expert and data parallelism for the MoE model architecture: a) Expert + Data (E + D) layout, b) Data + Expert (D + E) layout.

These two choices have a massive influence on the communication overhead.

At first glance, E+D may seem to be the better option, as it reduces the number of inter-node hops for the all-to-all, maximizing all-to-all throughput for the relevant message sizes. However, in reality, it is highly suboptimal.

Note that the inter-node communication volume for all-reduce is proportional to the number of parameters within a node. By placing experts close to each other, we place up to 8 experts within each node, compared to a single expert when using the D+E layout. As such, the former can incur up to 8x higher all-reduce communication overhead than the latter. On the other hand, the communication volume for all-to-all does not change regardless of the mapping. Therefore, by prioritizing data-before-expert parallelism, we optimize communication efficiency, enabling faster training across distributed devices.
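
For concreteness, the sketch below shows how the two rank-to-group mappings can be constructed with torch.distributed (illustrative only; the actual group setup is handled inside DeepSpeed). It assumes `world_size` is N, `ep_size` is E, and that groups are created on every rank in the same order.

```python
import torch.distributed as dist

def build_groups(world_size: int, ep_size: int, layout: str):
    assert world_size % ep_size == 0
    dp_size = world_size // ep_size          # number of replicas of each expert

    if layout == "E+D":
        # Experts on consecutive ranks (within a node); replicas strided by E.
        ep_rank_lists = [list(range(s, s + ep_size)) for s in range(0, world_size, ep_size)]
        ep_dp_rank_lists = [list(range(o, world_size, ep_size)) for o in range(ep_size)]
    elif layout == "D+E":
        # Replicas of the same expert on consecutive ranks; experts strided by N/E.
        ep_dp_rank_lists = [list(range(s, s + dp_size)) for s in range(0, world_size, dp_size)]
        ep_rank_lists = [list(range(o, world_size, dp_size)) for o in range(dp_size)]
    else:
        raise ValueError(layout)

    ep_groups = [dist.new_group(r) for r in ep_rank_lists]        # all-to-all runs here
    ep_dp_groups = [dist.new_group(r) for r in ep_dp_rank_lists]  # expert-gradient all-reduce
    return ep_groups, ep_dp_groups
```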

Communication overlapping

While using D+E mapping incurs less overall communication overhead than E+D, the all-to-all communication still accounts for a significant portion of the training time. To address this, we’ve implemented two overlapping techniques:

Stream-Based Overlapping for MoE Architecture: We split expert computation and communication across multiple streams. This involves using the main PyTorch stream for compute and a separate stream dedicated to communication. Additionally, we partition the expert’s GeMM into two chunks to preserve tensor-core efficiency while overlapping communication. This approach allows us to overlap half of the time of the first and second all-to-all with the expert’s GeMM computation.
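
A simplified sketch of this chunked overlap is shown below, assuming `expert_fn` runs the expert GeMMs on a [#experts, capacity, d_model] buffer and `ep_group` is the expert-parallel process group. It is illustrative only; the production implementation manages the CUDA streams explicitly and also covers the backward pass.

```python
import torch
import torch.distributed as dist

def overlapped_expert(dispatched: torch.Tensor, expert_fn, ep_group):
    # dispatched: [num_experts, capacity, d_model] tokens grouped per destination expert.
    # Split along the capacity dimension so each half gets its own all-to-all.
    c0, c1 = (c.contiguous() for c in dispatched.chunk(2, dim=1))
    r0, r1 = torch.empty_like(c0), torch.empty_like(c1)

    # Launch both forward all-to-alls asynchronously; they proceed on the
    # communication stream while the compute stream stays free.
    w0 = dist.all_to_all_single(r0, c0, group=ep_group, async_op=True)
    w1 = dist.all_to_all_single(r1, c1, group=ep_group, async_op=True)

    w0.wait()                          # compute stream waits only for chunk 0
    e0 = expert_fn(r0)                 # overlaps with chunk 1's all-to-all

    out0 = torch.empty_like(e0)
    b0 = dist.all_to_all_single(out0, e0, group=ep_group, async_op=True)

    w1.wait()
    e1 = expert_fn(r1)                 # overlaps with chunk 0's return all-to-all

    out1 = torch.empty_like(e1)
    b1 = dist.all_to_all_single(out1, e1, group=ep_group, async_op=True)
    b0.wait()
    b1.wait()
    return torch.cat([out0, out1], dim=1)
```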

System Architecture Co-design for Arctic: To allow all-to-all communication to overlap with compute, we created the Dense-MoE hybrid architecture, in which the communication in the MoE path can overlap with the computation in the dense path. In fact, this design allows us to overlap each all-to-all with a different part of the transformer layer: for example, the first all-to-all with the attention computation and the second all-to-all with the dense MLP computation.

We’ve also extended this overlapping scheme to the backward pass by explicitly signaling dependency breaks in PyTorch’s computation graph. By waiting on the second stream, which runs the MoE path, in the middle of the transformer layer, we effectively eliminate the communication overhead. This optimization comes at the cost of adding a small percentage of extra parameters to the model, in exchange for a substantial increase in training system efficiency.

Conclusion

Training large MoE models like Arctic is complex and riddled with memory and efficiency challenges. Thanks to the open-source community, we’ve had a solid foundation to build upon.

In this blog post, we shared our exploration of, and solutions to, the bottlenecks that arise during large MoE model training. By integrating system design with MoE architecture considerations and developing targeted system optimizations, we’ve been able to make significant strides in this area.

We know there is still a lot of room for improvement, but we hope that our discussion here illuminates the path towards innovative and efficient MoE training system design. Let’s continue this journey together! 🚀

Learn more in our Snowflake Arctic series

Check out our other blog posts that dive into Snowflake Arctic training, including data cleaning, training system design, and modeling and system co-design for optimal throughput. Stay tuned as more updates continue to drop in the Snowflake Arctic cookbook catalog.
