Two minutes NLP — Switch Transformers and huge sparse language models
Mixture of Experts, the Switch FFN layer, and scaling properties
Hello fellow NLP enthusiasts! As language models become bigger and bigger, there have been attempts at training sparse language models with even more parameters than famous dense language models like GPT-3. In this article we look at some of the details that make this possible. Enjoy! 😄
Mixture of Experts models
In machine learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) models instead select different parameters for each incoming example, where the parameters are grouped in independent clusters called experts. At each prediction, only a subset of all the experts is used.
The result is a model with sparse activations and a huge number of parameters, but with a constant computational cost as the total number of activated neurons is always the same. MoE models achieved notable successes but their adoption has been hindered by complexity, communication costs, and training instability.
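To make the idea concrete, here is a minimal numpy sketch of an MoE layer (not the paper's implementation; the shapes, weights, and top-k of 2 are illustrative toy choices). A router scores the experts for a token, and only the top-k experts actually run, so the compute per token stays constant no matter how many experts exist:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, n_experts, top_k = 8, 16, 4, 2

# Each expert is an independent two-layer FFN with its own parameters.
experts = [
    (rng.normal(size=(d_model, d_ff)) * 0.1, rng.normal(size=(d_ff, d_model)) * 0.1)
    for _ in range(n_experts)
]
# The router scores every expert for a given token.
w_router = rng.normal(size=(d_model, n_experts)) * 0.1

def moe_layer(x):
    """Route token x to its top-k experts; cost is k FFNs, not n_experts FFNs."""
    logits = x @ w_router
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    chosen = np.argsort(probs)[-top_k:]  # indices of the top-k experts
    out = np.zeros_like(x)
    for i in chosen:
        w_in, w_out = experts[i]
        out += probs[i] * (np.maximum(x @ w_in, 0) @ w_out)  # gate-weighted sum
    return out

y = moe_layer(rng.normal(size=d_model))
```

Adding more experts grows the parameter count, but each token still pays for only `top_k` FFN evaluations.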
The Switch Transformer
The Switch Transformer aims at addressing the issues related to MoE models by simplifying their routing algorithm (i.e. the part of the model that decides which expert to use) and designing improved models with reduced communication and computational costs.
The guiding design principle for Switch Transformers is to maximize the parameter count of a Transformer model in a simple and computationally efficient way. The result is an increase in the scale of neural language models achieved by efficiently combining data, model, and expert-parallelism to create models with up to a trillion parameters.
The new Switch Transformer Encoder Block
The new Switch Transformer encoder block differs from the standard Transformer encoder block in one way: it replaces the dense feed-forward network (FFN) layer with a sparse Switch FFN layer (the light-blue box in the image). The layer operates independently on each token in the sequence. The diagram shows two tokens (x1 and x2) being routed across four FFN experts, with the router routing each token independently. The Switch FFN layer output is then the output of the selected FFN multiplied by the router gate value.
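The description above can be sketched in a few lines of numpy (an illustrative toy, not the paper's code; dimensions and weights are made up). The key difference from a classic MoE layer is top-1 routing: each token goes to exactly one expert, and the gate value scales that expert's output:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_ff, n_experts = 8, 16, 4

# Four independent FFN experts, as in the diagram.
experts = [
    (rng.normal(size=(d_model, d_ff)) * 0.1, rng.normal(size=(d_ff, d_model)) * 0.1)
    for _ in range(n_experts)
]
w_router = rng.normal(size=(d_model, n_experts)) * 0.1

def switch_ffn(tokens):
    outputs = []
    for x in tokens:                          # the router decides per token
        logits = x @ w_router
        probs = np.exp(logits - logits.max())
        probs /= probs.sum()
        i = int(np.argmax(probs))             # top-1 routing: a single expert
        w_in, w_out = experts[i]
        expert_out = np.maximum(x @ w_in, 0) @ w_out
        outputs.append(probs[i] * expert_out)  # scale by the router gate value
    return np.stack(outputs)

x1, x2 = rng.normal(size=(2, d_model))
y = switch_ffn([x1, x2])
```

Routing to a single expert is what simplifies the algorithm and cuts communication cost: each token's activations are sent to one expert device instead of several.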
Other contributions of the Switch Transformer
- Improved pre-training and fine-tuning techniques: mixed-precision training, an initialization scheme that allows scaling to a larger number of experts, and increased expert regularization that improves sparse model fine-tuning and multi-task training.
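As a rough sketch of the initialization idea: the paper draws weights from a truncated normal with standard deviation sqrt(s / fan_in) and reduces the scale s by a factor of 10 (to 0.1) to stabilize training as the number of experts grows. The helper below is a hedged illustration, not the exact implementation (in particular, `np.clip` is a crude stand-in for proper resampling-based truncation):

```python
import numpy as np

rng = np.random.default_rng(0)

def scaled_init(shape, fan_in, scale=0.1):
    """Truncated-normal init with std = sqrt(scale / fan_in).

    The Switch Transformer paper reduces the default scale by 10x
    (1.0 -> 0.1), which was found to help when scaling to many experts.
    np.clip is a crude approximation of truncation by resampling.
    """
    std = np.sqrt(scale / fan_in)
    w = rng.normal(0.0, std, size=shape)
    return np.clip(w, -2 * std, 2 * std)  # keep values within 2 standard deviations

w = scaled_init((16, 8), fan_in=16)
```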
- Sparse models can be distilled into small dense models, reducing the model size by up to 99% while preserving 30% of the quality gains of the large sparse teacher.
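The distillation objective can be sketched as a mixture of a soft loss against the teacher's output distribution and a hard loss against the ground-truth label (the paper reports a 0.25 teacher / 0.75 label mixture working well). The code below is an illustrative single-example version with made-up logits, not the paper's training code:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def distill_loss(student_logits, teacher_logits, label, alpha=0.25):
    """Mix soft cross-entropy against the sparse teacher's distribution
    with hard cross-entropy against the ground-truth label."""
    s = softmax(student_logits)
    t = softmax(teacher_logits)
    soft = -np.sum(t * np.log(s + 1e-9))  # match the teacher's distribution
    hard = -np.log(s[label] + 1e-9)       # match the ground-truth label
    return alpha * soft + (1 - alpha) * hard

loss = distill_loss(np.array([2.0, 0.0, 0.0]), np.array([3.0, 0.0, 0.0]), label=0)
```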
Benchmarks
The models are pre-trained on C4 (Colossal Clean Crawled Corpus) and then finetuned for downstream tasks on the GLUE and SuperGLUE benchmarks.
Switch Transformer outperforms the strongly tuned T5 model, achieving pre-training speedups of over 7x while using the same FLOPs per token.
Moreover, the improvements hold even with limited computational resources, using as few as two experts. These improvements extend into multilingual settings where gains are measured over the mT5-Base version across all 101 languages.
The next picture highlights the scaling properties of the Switch Transformer as the number of experts increases.
Switch Transformer discussion
The paper contains an interesting discussion section in FAQ form, which I report here.
Isn’t Switch Transformer better due to sheer parameter count? Yes, and by design! Parameters, independent of the total FLOPs used, are a useful axis to scale neural language models. Large models have been exhaustively shown to perform better (Kaplan et al., 2020). But in this case, our model is more sample efficient and faster while using the same computational resources.
I don’t have access to a supercomputer — is this still useful for me? Though this work has focused on extremely large models, we also find that models with as few as two experts improve performance while easily fitting within memory constraints of commonly available GPUs or TPUs (details in Appendix D). We therefore believe our techniques are useful in small-scale settings.
Do sparse models outperform dense models on the speed-accuracy Pareto curve? Yes. Across a wide variety of different model sizes, sparse models outperform dense models per step and on wall-clock time. Our controlled experiments show that, for a fixed amount of computation and time, sparse models outperform dense models.
I can’t deploy a trillion parameter model — can we shrink these models? We cannot fully preserve the model quality, but compression rates of 10 to 100x are achievable by distilling our sparse models into dense models while achieving ≈30% of the quality gain of the expert model.
Why use Switch Transformer instead of a model-parallel dense model? On a time basis, Switch Transformers can be far more efficient than dense models with sharded parameters (Figure 6). Also, we point out that this decision is not mutually exclusive — we can, and do, use model-parallelism in Switch Transformers, increasing the FLOPs per token but incurring the slowdown of conventional model-parallelism.
Why aren’t sparse models widely used already? The motivation to try sparse models has been stymied by the massive success of scaling dense models (the success of which is partially driven by co-adaptation with deep learning hardware as argued in Hooker (2020)). Further, sparse models have been subject to multiple issues including (1) model complexity, (2) training difficulties, and (3) communication costs. Switch Transformer makes strides to alleviate these issues.
Conclusions and next steps
In this article, we learned about Mixture of Experts (MoE) models, what their limitations are, and how the Switch Transformer addresses them. We peeked into the new Switch Transformer encoder block and read about benchmarks against T5.
Possible next steps are:
- Read the Scaling Laws for Neural Language Models paper.
- Learn about the Gopher language model and how its performance improves with scale.