[Paper] Adafactor: Adaptive Learning Rates with Sublinear Memory Cost

Essential for fine-tuning T5 v1.1 and mT5 models

Ceshine Lee
Veritable
5 min read · Apr 18, 2021


Originally published at https://blog.ceshine.net. Some of the mathematical expressions and explanations are removed in the Medium version. Please check the link for the complete article.

Motivation

The Adafactor optimizer, in my experience, can provide much better convergence when fine-tuning the T5 v1.1 and mT5[1] pre-trained models. However, I encountered problems when using a custom learning rate scheduler with the Adafactor implementation in the huggingface/transformers library. I combed through the paper and the source code to find and fix the cause of the problem, which turned into a tiny contribution to the library.

To further squeeze value from the time I’ve invested, I wrote this post to introduce the key ideas of the Adafactor optimizer and analyze the corresponding chunks of code in the huggingface/transformers implementation (which was taken from the fairseq library). Working examples as Kaggle notebooks are also provided: T5 v1.1 and mT5.

(Note: for the original T5 pre-trained models[2], which were pre-trained with a mixture of unsupervised and supervised objectives, the Adam or AdamW optimizers are enough to get good results.)

Overview

The popular Adam[3] optimizer keeps two additional values for each parameter: one stores the momentum; the other stores the exponentially smoothed squared gradients. The memory requirement is therefore tripled compared to the vanilla SGD optimizer. Adafactor dramatically reduces this requirement (by more than half) while retaining comparable performance (tested on the WMT ’14 En→De translation task with the classic transformer seq2seq architecture).

The authors of Adafactor first propose replacing the full smoothed-squared-gradients matrix with a low-rank approximation. This reduces the memory requirement for the squared gradients from O(nm) to O(n+m).

Secondly, Adafactor removes momentum entirely, which causes some training instability. The authors suspect that an out-of-date second-moment accumulator (the exponential smoothing of the squared gradients) might be the cause. By increasing the decay rate over time (so that early steps rely mostly on recent gradients while later steps average over a longer history) and clipping the gradient update, Adafactor can converge normally even without momentum.

Finally, Adafactor multiplies the learning rate by the scale of the parameters (this is called “relative step size”). The authors showed that training with relative step sizes provides more robustness to differently scaled embedding parameters.

Factored Second Moment Estimation

Adafactor factorizes the exponential moving average of squared gradients V∈R^(n×m) into RS, where R∈R^(n×1) and S∈R^(1×m). Minimizing the I-divergence (generalized Kullback-Leibler divergence) between V and RS has an analytic solution:
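In the paper’s notation (with 1_n and 1_m denoting all-ones column vectors), the solution can be written as:

R = V 1_m,  S = (1_n^T V) / (1_n^T V 1_m),  so that  V ≈ RS = (V 1_m)(1_n^T V) / (1_n^T V 1_m)

In other words, R holds the row sums of V and S holds the normalized column sums.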

The solution only requires us to store the moving averages of the row sums and the column sums:
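Written out, roughly as in the paper (G_t is the gradient, ε_1 a small stabilizing constant, and β̂_2t the decay parameter at step t):

R_t = β̂_2t R_{t−1} + (1 − β̂_2t) (G_t² + ε_1 1_n 1_m^T) 1_m
C_t = β̂_2t C_{t−1} + (1 − β̂_2t) 1_n^T (G_t² + ε_1 1_n 1_m^T)
V̂_t = R_t C_t / (1_n^T R_t)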

Alg 1: The low-rank approximation of V

Looking at the corresponding part of the implementation, here is the code that updates the moving averages of the row and column sums (here, update holds the squared gradients plus a small epsilon):

exp_avg_sq_row.mul_(beta2t).add_(update.mean(dim=-1), alpha=1.0 - beta2t)
exp_avg_sq_col.mul_(beta2t).add_(update.mean(dim=-2), alpha=1.0 - beta2t)

And here is the analytic solution (rsqrt means the reciprocal of the square root, 1/sqrt(input)):
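# The row factor is normalized by its mean so that r_factor * c_factor
# approximates 1/sqrt(V_hat), the factored second-moment scaling.
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
update = r_factor * c_factor * grad

(This is a sketch modeled on the _approx_sq_grad helper in the huggingface/transformers implementation; the exact lines may differ slightly between releases.)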

Removing Momentum

The authors demonstrated that a fast decay of the second-moment estimator leads to convergence problems, while a slow decay leads to stability problems:

Table 1

And the problem with slow decay is that the out-of-date second-moment estimates lead to larger-than-desired updates.

Update Clipping

One of the proposed solutions is to clip the update according to the root-mean-square over all parameters in a weight matrix or vector:
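A rough sketch of the idea in PyTorch (the threshold d below corresponds to the clip_threshold argument in the huggingface/transformers implementation; this is not the library’s exact code):

import torch

def rms(t: torch.Tensor) -> torch.Tensor:
    # Root-mean-square over all elements of a tensor.
    return t.norm(2) / (t.numel() ** 0.5)

def clip_update(update: torch.Tensor, d: float = 1.0) -> torch.Tensor:
    # Scale the update down whenever its RMS exceeds the threshold d:
    # U_hat = U / max(1, RMS(U) / d)
    return update / torch.clamp(rms(update) / d, min=1.0)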

(Details skipped here. Please refer to the complete post.)

Increasing Decay Parameter

Another solution is to use an increasing β_2. The proposed family of schedules is:
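β̂_2t = 1 − t^(−c), with c > 0

With this schedule, the decay parameter starts near zero (fast decay) and approaches 1 (slow decay) as training progresses. A minimal sketch in Python (the paper uses c = 0.8 in its experiments):

import math

def beta2_schedule(step: int, c: float = 0.8) -> float:
    # beta2_t = 1 - t^(-c): equal to 0 at the first step, approaching 1 as t grows.
    return 1.0 - math.pow(step, -c)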

(Details skipped here. Please refer to the complete post.)

Relative Step Size

Adafactor multiplies the given learning rate by the scale of each parameter, defined as the root-mean-square of its components. Therefore, parameters with larger values get larger updates, and vice versa:
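α_t = max(ε_2, RMS(X_{t−1})) · ρ_t

(Here X_{t−1} is the parameter matrix or vector being updated, and ε_2 is a small constant, 10^(−3) in the paper.)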

The paper calls α_t the “absolute step size” and ρ_t the “relative step size.”
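A minimal sketch of the step-size computation, using the defaults proposed in the paper (ρ_t = min(10^(−2), 1/√t) and ε_2 = 10^(−3)):

import math
import torch

def absolute_step_size(param: torch.Tensor, step: int, eps2: float = 1e-3) -> float:
    # Relative step size: rho_t = min(1e-2, 1 / sqrt(t)).
    rho_t = min(1e-2, 1.0 / math.sqrt(step))
    # Parameter scale: root-mean-square of the parameter's components.
    param_rms = (param.norm(2) / (param.numel() ** 0.5)).item()
    # Absolute step size: alpha_t = max(eps2, RMS(X_{t-1})) * rho_t.
    return max(eps2, param_rms) * rho_t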

(Details skipped here. Please refer to the complete post.)

Now we have the complete Adafactor algorithm:
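As a compact, simplified sketch of a single Adafactor update for a weight matrix (no momentum), condensing the pieces above; the function and variable names are for illustration only and this is not the library’s exact code (R and C are persistent state tensors of shapes (n,) and (m,)):

import math
import torch

def adafactor_step(param, grad, R, C, step, c=0.8, d=1.0, eps1=1e-30, eps2=1e-3):
    # Increasing decay parameter: beta2_t = 1 - t^(-c).
    beta2t = 1.0 - math.pow(step, -c)
    sq = grad ** 2 + eps1
    # Factored second-moment statistics: row means (R) and column means (C).
    R.mul_(beta2t).add_(sq.mean(dim=-1), alpha=1.0 - beta2t)
    C.mul_(beta2t).add_(sq.mean(dim=-2), alpha=1.0 - beta2t)
    # Reconstruct an approximation of 1/sqrt(V_t) from the factors.
    r_factor = (R / R.mean(dim=-1, keepdim=True)).rsqrt().unsqueeze(-1)
    c_factor = C.unsqueeze(-2).rsqrt()
    update = grad * r_factor * c_factor
    # Update clipping: keep RMS(update) below the threshold d.
    rms_u = update.norm(2) / math.sqrt(update.numel())
    update = update / max(1.0, (rms_u / d).item())
    # Relative step size scaled by the parameter's RMS.
    rho_t = min(1e-2, 1.0 / math.sqrt(step))
    rms_x = param.norm(2) / math.sqrt(param.numel())
    alpha_t = max(eps2, rms_x.item()) * rho_t
    param.add_(update, alpha=-alpha_t)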

(More implementation details skipped here. Please refer to the complete post.)

Working Examples

I’ve written some code that fine-tunes T5 and mT5 models on NLI datasets using PyTorch Lightning. This is where I set up the Adafactor optimizer:

optimizer = Adafactor(
    self.model.parameters(),
    # Use the externally supplied learning rate (and scheduler) instead of
    # Adafactor's built-in relative-step schedule.
    relative_step=False, warmup_init=False,
    clip_threshold=1.0, lr=self.config.learning_rate,
    # Still scale the step size by the root-mean-square of each parameter.
    scale_parameter=True
)

I used a combination of linear warmup and cosine annealing to schedule the learning rates:

scheduler = {
    'scheduler': pls.lr_schedulers.MultiStageScheduler(
        [
            pls.lr_schedulers.LinearLR(
                optimizer, 0.0001, lr_durations[0]),
            CosineAnnealingLR(optimizer, lr_durations[1])
        ],
        start_at_epochs=break_points
    ),
    'interval': 'step',
    'frequency': 1,
    'strict': True,
}

I’ve published a Kaggle notebook that fine-tunes the google/t5-v1_1-base model on the MultiNLI dataset and gets a competitive result. I've observed that my learning rate schedule performs better than the inverse-square root decay recommended by the paper.

An mT5 version, which further fine-tunes an MNLI-fine-tuned google/mt5-base model on a multi-lingual dataset, is also available. Because the multi-lingual corpus is small, I froze the embedding matrix in this one to prevent overfitting.

References

  1. Xue, L., Constant, N., Roberts, A., Kale, M., Al-Rfou, R., Siddhant, A., … Raffel, C. (2020). mT5: A massively multilingual pre-trained text-to-text transformer.
  2. Raffel, C., Shazeer, N., Roberts, A., Lee, K., Narang, S., Matena, M., … Liu, P. J. (2019). Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer.
  3. Kingma, D. P., & Ba, J. L. (2015). Adam: A method for stochastic optimization.
