Compressing Wav2vec 2.0

Georgian
Georgian Impact Blog
11 min read · Mar 24, 2021

By Zilun Peng, Akshay Budhkar, Jumana Nassour, Ilana Tuil and Jason Levy

Huge transformer models have revolutionized the AI landscape. Large language models such as BERT [5], GPT-3 [6], and ULMFiT [7] have disrupted the NLP world, effectively replacing prior methods across NLP tasks.

However, these improvements come at a cost. Recent papers, such as those by Strubell et al. [8] and Bender et al. [9], warn of the environmental cost of training and using these giant models.

In a business setting, these costs are also prohibitive: slow, memory-guzzling models run up large cloud computing bills, and training them incurs significant GPU costs. For BERT, a common solution is to compress the model into a smaller one that lowers these costs while maintaining accuracy [10, 11].

Recent advances in transformer models for transcription create the same need, but how do we compress an ASR model? To date, a Google search for distilling speech models doesn’t turn up an answer. So we set out to find one and share it with you.

ICYMI, some background on wav2vec 2.0

The speech-processing model wav2vec 2.0 [1] generates latent representations of speech audio waveforms. Wav2vec 2.0 is a large model, with up to 317 million parameters. In this post, we’ll show how to compress a wav2vec 2.0 model with little performance loss, speeding up inference for easier deployment in a production environment. For more information on wav2vec 2.0, take a look at our previous post.

What is model compression?

Model compression means shrinking a machine learning model into a smaller one without losing too much performance. Suppose we have a speech-processing model (the original model) that converts audio waveforms into text. A compressed model should be smaller than the original. Given the same speech audio waveform, a good compressed model generates text that is as close as possible to the original model’s output.

Image credit: https://www.pngaaa.com/detail/284115

Some common model compression techniques are: pruning, quantization, and knowledge distillation. Let’s look at how to use them to compress the wav2vec 2.0 model.

Why should we compress wav2vec 2.0? And how?

Wav2vec_big_960h is a wav2vec 2.0 model pre-trained on 960 hours of unlabeled audio from the LibriSpeech dataset, then fine-tuned on the labeled version of the same 960 hours. The table below shows its word error rate (WER) on the dev-clean set from LibriSpeech, as well as its size and number of parameters. For more information on wav2vec 2.0, take a look at our previous post.

Clearly, wav2vec_big_960h is a large model, and we want to compress it so that the compressed model is smaller and has fewer parameters. A good compressed model scores a WER close to 2.63% and spends less than 123 seconds on inference.

Distilling knowledge in wav2vec 2.0

In this section, we introduce knowledge distillation and look into the code to calculate knowledge distillation loss. Then, we look at how to distill knowledge in wav2vec 2.0 and see some results. The code for knowledge distillation is in kd_training.py and you can find a demo that applies knowledge distillation on wav2vec 2.0 in this notebook. We used the PyTorch framework in our sample code snippets.

What is knowledge distillation?

The goal of knowledge distillation is to transfer knowledge from a teacher model into a student model. Typically, the teacher model is a large neural network and the student model is a smaller version with fewer layers. Through knowledge distillation training, the teacher’s knowledge is compressed into the student model.

The figure below illustrates the knowledge distillation training process. We feed the same input data into both the teacher and the student model, and calculate a loss from their outputs. We then update the student model by performing gradient descent on that loss; the teacher model is not updated.

Image credit: https://towardsdatascience.com/knowledge-distillation-simplified-dd4973dbc764
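The training loop described above can be sketched as follows. The call signatures and helper names are illustrative, not the exact code from kd_training.py:

import torch

def distillation_step(teacher_model, student_model, optimizer, batch, distillation_loss):
    # One knowledge distillation update: forward both models, compute the loss, update the student.
    teacher_model.eval()          # the teacher is frozen and never updated
    student_model.train()
    with torch.no_grad():
        teacher_output = teacher_model(batch)      # teacher forward pass, no gradients
    student_output = student_model(batch)          # student forward pass, gradients enabled
    loss = distillation_loss(student_output, teacher_output)  # e.g. the KL-based loss described below
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()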

Student model architecture

Wav2vec 2.0 has two major components: CNN layers and transformer layers. Our student wav2vec 2.0 model has fewer transformer layers than the teacher model, but the same number of CNN layers. It’s possible to reduce the number of CNN or transformer layers when designing the student model. Here, we choose not to reduce CNN layers because there are only seven, and they are not a huge computational bottleneck compared to transformer layers.

The architectures of the student and teacher models are defined in student_wav2vec2.py and teacher_wav2vec2.py.

Knowledge distillation loss

In this section, we walk you through the code for calculating knowledge distillation loss.

Both the student and teacher models output probability distributions over the possible tokens. Knowledge distillation loss is measured using the Kullback–Leibler (KL) divergence between probability distributions of the student and the teacher model. We want the student model to have a similar probability distribution to the teacher model.

The process of computing the probability distribution is the same for both the teacher and student models: we take the output of the transformer layers, then pass it to a linear layer, followed by a softmax operation. We need to first get the probability distribution of the teacher and the student model, then calculate the loss.

We start by setting the teacher model to evaluation mode, because we only need inference results from it. In PyTorch, calling eval() on a model puts its dropout and batch normalization layers into evaluation mode; otherwise, inference results would be inconsistent between forward passes.

To calculate the teacher model’s probability distribution, we pass the input (batch in the code below) into the teacher model. batch is a tuple containing the raw audio waveform, the padding mask, and other arguments. When the batch size is larger than 1, we pad each input waveform to the length of the longest waveform in the batch. padding_mask specifies which part of the input waveform to mask. By wrapping the forward pass in torch.no_grad(), we ensure that no gradients are calculated for the teacher.
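A minimal sketch of that step, assuming batch is laid out as (waveform, padding_mask, ...) and that the teacher wrapper returns logits; the exact call signature in teacher_wav2vec2.py may differ:

self.teacher_model.eval()                      # dropout / batch norm layers in evaluation mode
waveform, padding_mask = batch[0], batch[1]    # assumed batch layout
with torch.no_grad():                          # no gradients for the teacher
    teacher_logits = self.teacher_model(waveform, padding_mask)   # transformer output -> linear layer
teacher_prob = torch.nn.functional.softmax(teacher_logits / self.temperature, dim=-1)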

Apply a similar process to the student model.

Next, calculate the log probability for the student model since we need it to calculate knowledge distillation loss below. Unlike the teacher model, the student model is in training mode and gradients are computed during the forward pass.
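A matching sketch for the student model, with the same assumed names; the student stays in training mode, so gradients flow through this forward pass:

student_logits = self.student_model(waveform, padding_mask)        # gradients are computed here
student_log_prob = torch.nn.functional.log_softmax(student_logits / self.temperature, dim=-1)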

Finally, we are ready to compute the knowledge distillation loss! The knowledge distillation loss is the KL divergence between probability distributions of the teacher model and the student model. The code below shows how to calculate it.

torch.nn.functional.kl_div(student_log_prob, teacher_prob, reduction='batchmean') * (self.temperature**2)

The temperature term in the code above is the same one Hinton et al. [3] use in their knowledge distillation loss. A higher temperature results in a softer probability distribution. We calculated student_log_prob and teacher_prob using the temperature term, so we only need to scale the loss with the temperature here. Hinton et al. multiply the KL divergence by temperature², so the overall loss is not too small and gradient values are sufficiently large. This ensures the student model learns efficiently. In our experiments, we used a temperature of 1. We did not experiment with other values, since we received good results with it.

Why did we use student model log probabilities? The KL divergence between probability distributions P and Q over the space X is defined as: KL(P || Q) = \sum_{x \in X} P(x) log P(x) - P(x) log Q(x). F.kl_div works as follows:

F.kl_div(student_log_prob, teacher_prob) = teacher_prob * [log(teacher_prob) - student_log_prob].

Therefore, we need to use student model log probabilities to match the definition of KL divergence.

Our results

The following table shows trained student model performance versus its teacher model, which is the original wav2vec 2.0 model.

The student model is 4.8 times smaller than the original model, with a 7% loss in word error rate. On GPU, the student model is 2.4 times faster than the teacher model. On CPU, the student model is 2.8 times faster. Distributed inference would further speed up the student model’s inference.

How did we train the model?

In this section, we share some details on training to get the results above.

Besides the knowledge distillation loss, we have a feature penalty term in our objective function. Our overall objective is: knowledge distillation loss + feature penalty. The goal of the feature penalty is to regularize the CNN layers. In the code below, features is the output of the CNN layers. We take the square of features, then calculate the mean, so the feature penalty is an L2 regularization term.
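Sketched in code, with features standing for the CNN encoder output as described above (how features is obtained from the model is omitted here):

feature_penalty = features.pow(2).mean()   # square the CNN output, then take the mean: an L2 penalty
loss = kd_loss + feature_penalty           # overall training objective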

We used the Adam optimizer with weight decay. The learning rate starts at 2.5e-5. We warm up the learning rate by increasing it in the first two epochs, then linearly decrease the learning rate in every epoch.

The code below shows how we configure the optimizer and learning rate scheduler. We define a lambda function, lr_lambda, to determine how to schedule the learning rate. If we have not finished the first two epochs (i.e. current_epoch < self.num_lr_warm_up_epoch), we increase the learning rate; otherwise, we decrease it.
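A sketch of that configuration, written as a standalone snippet (so the self. prefixes from the text are dropped). The warm-up ramp shape and the weight-decay value are placeholders; only the Adam-with-weight-decay choice, the 2.5e-5 starting learning rate, and the warm-up-then-linear-decay schedule come from our setup:

import torch

num_lr_warm_up_epoch = 2    # warm up over the first two epochs
num_epoch = 32              # total number of training epochs

optimizer = torch.optim.Adam(student_model.parameters(), lr=2.5e-5, weight_decay=1e-4)  # weight-decay value is a placeholder

def lr_lambda(current_epoch):
    # Scale the learning rate up during warm-up, then decrease it linearly afterwards.
    # The exact ramp here is an assumption; only the overall structure comes from the post.
    if current_epoch < num_lr_warm_up_epoch:
        return float(current_epoch + 1) / num_lr_warm_up_epoch
    return max(0.0, float(num_epoch - current_epoch) / (num_epoch - num_lr_warm_up_epoch))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
# scheduler.step() is called once at the end of every epoch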

Our student model has the same number of CNN layers as the teacher model. It only has four transformer layers versus 24 in the teacher model.

Initialization is crucial to the student model’s performance. We first trained a larger student model with 12 transformer layers, initialized from a pre-trained model with 24 layers. We initialized every transformer layer in this model by taking every second transformer layer from the pre-trained model. It scored a 6% WER on the validation set. Then, we initialized every transformer layer of the student model (with 4 transformer layers in total) by taking every third layer from the larger student model (with 12 transformer layers).
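A sketch of that layer-selection initialization, using the variable names explained below. The .encoder.layers attribute path is an assumption about how the wav2vec 2.0 wrapper exposes its transformer layers:

num_trans_layer_student_model = 4            # transformer layers in the student model
num_trans_layer_student_init_model = 12      # transformer layers in the larger model used for initialization

step = num_trans_layer_student_init_model // num_trans_layer_student_model
student_init_model_selected_transformer_layers = list(range(0, num_trans_layer_student_init_model, step))  # [0, 3, 6, 9]
student_model_transformer_layers = list(range(num_trans_layer_student_model))                              # [0, 1, 2, 3]

for student_layer, init_layer in zip(student_model_transformer_layers,
                                     student_init_model_selected_transformer_layers):
    # Copy the selected transformer layer's weights from the large model into the student model.
    student_model.encoder.layers[student_layer].load_state_dict(
        student_init_model.encoder.layers[init_layer].state_dict())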

The code above shows how to initialize the student model.

num_trans_layer_student_model is the number of transformer layers in the student model, which is 4 in the example above.

num_trans_layer_student_init_model is the number of transformer layers in the large model used to initialize the student model, 12 in the example above.

student_init_model_selected_transformer_layers lists layer indices from the large model that will be used to initialize the student model. In our example, these are transformer layers 0, 3, 6 and 9 from the large model.

student_model_transformer_layers lists all transformer layer indices for the student model. In our example, student_model_transformer_layers = [0,1,2,3].

In the for loop, we take each pair of layer indices from student_model_transformer_layers and student_init_model_selected_transformer_layers, and use the selected transformer layer from the large model to initialize the corresponding transformer layer in the student model.

We used all training data from LibriSpeech to train the student model including train-clean-100, train-clean-360, and train-other-500. We trained for 32 epochs.

Pruning wav2vec 2.0

What is pruning?

Pruning generally means setting a criterion on the weights of a model and setting to 0 the weights that meet that criterion.

Sensitivity pruning

We used Distiller’s sensitivity pruner to prune the wav2vec 2.0 model. To use the pruner, you set a sensitivity (a numerical value) for each layer, and it prunes the weights whose magnitude is smaller than the product of that sensitivity and the standard deviation of the weights in that layer.
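The snippet below is not Distiller’s API; it is only a plain-PyTorch sketch of that rule (threshold = sensitivity times the standard deviation of the layer’s weights), applied to a hypothetical linear layer with an illustrative sensitivity value:

import torch

def sensitivity_prune_(weight: torch.Tensor, sensitivity: float) -> None:
    # Zero out weights whose magnitude is below sensitivity * std of this layer's weights.
    threshold = sensitivity * weight.std()
    weight.mul_((weight.abs() >= threshold).to(weight.dtype))

# Example: prune one linear layer in place with an illustrative sensitivity of 0.2.
with torch.no_grad():
    sensitivity_prune_(some_linear_layer.weight, sensitivity=0.2)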

Our results

Our pruning results show that a wav2vec 2.0 model with 23% sparsity performs similarly to the non-pruned model. We did not find any improvement in inference speed with a pruned model, which we expected: pruning only sets certain weights to 0, and without sparse kernels those zeros do not reduce computation.

Quantizing wav2vec 2.0

What is quantization?

Parameters of a neural network are usually represented as 16-bit or 32-bit floating point numbers. The classic way to compress the neural network is to quantize its parameters into 8-bit integers. The resulting quantized neural network is smaller than the original model.

Static quantization and dynamic quantization

PyTorch provides three types of quantization: dynamic quantization, static quantization, and quantization-aware training. Dynamic quantization quantizes only the weights ahead of time and quantizes activations on the fly in each forward pass. Static quantization quantizes both weights and activations before inference, and requires a calibration step beforehand to determine the parameters for quantizing activations [4]. Next, we’ll see how static and dynamic quantization can and cannot be applied to wav2vec 2.0.

Static quantization

PyTorch does not support quantization for the gelu activation (as of December 2020), which is used in wav2vec 2.0. A closed pull request shows that support was being added but was suspended temporarily due to test case failures. One possible workaround is to dequantize before gelu and re-quantize afterwards, but this causes some performance degradation.

Dynamic quantization

As of December 2020, there is no support for dynamically quantizing the multi-head attention mechanism, which is an essential part of a transformer layer. Related issues remain open in PyTorch and fairseq.

As a result, tensors entering multi_head_attention_forward must not be quantized. We also observed that dynamic quantization wraps weights and biases in accessor methods, so a costly linear_unpack operation is needed to access them during inference. We therefore save the weights and biases before inference (see the code below), so that linear_unpack is not performed for every sample during inference, which makes inference more efficient.

In the code below, we access and save the weights and biases of q_proj, k_proj, v_proj, in_proj, and out_proj. These are components of the multi-head attention mechanism, which in turn is part of a transformer layer.
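A sketch of that step under our assumptions: the helper name and the simple dictionary cache are ours, and _weight_bias() is the private PyTorch accessor that performs the unpack. The actual notebook code may organize this differently:

import torch

def cache_attention_weights(quantized_model):
    # Unpack the quantized attention projections once, so linear_unpack is not
    # triggered for every sample at inference time.
    target_names = ("q_proj", "k_proj", "v_proj", "in_proj", "out_proj")
    cached = {}
    for name, module in quantized_model.named_modules():
        if isinstance(module, torch.nn.quantized.dynamic.Linear) and name.endswith(target_names):
            weight, bias = module._weight_bias()   # the costly unpack, done once up front
            cached[name] = (weight, bias)
    return cached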

It’s simple to apply dynamic quantization to wav2vec 2.0, as you see in the code below.
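Something like the two lines below, where cache_attention_weights is the hypothetical preparation helper sketched in the previous section:

quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)
cached_weights = cache_attention_weights(quantized_model)   # prepare the quantized model for inference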

In the first line, we specify that we want to quantize the linear layers of the wav2vec 2.0 model. Note that PyTorch only supports quantizing Linear, LSTM, LSTMCell, GRUCell, and RNNCell layers as of December 2020. In the second line, we prepare the quantized model for inference, as explained above. For a demo of how to quantize wav2vec 2.0, check out this notebook.

The wav2vec_big_960h model scores a WER of 2.63% on dev-clean while spending 4,433 seconds on inference. Its dynamically quantized version spends 4,079 seconds and scores a WER of 2.75% on dev-clean. The quantized model does not significantly speed up inference, but it is much smaller than the original model, so it can be useful when model size is a concern.

PyTorch only supports quantization on the CPU, so we reported CPU inference times in the above table.

Conclusion

In this post, we introduced model compression and three common techniques: knowledge distillation, pruning, and quantization. We also showed how to compress wav2vec 2.0 using each of them. Knowledge distillation creates a compressed model with fast inference, at the cost of some performance. The quantized model has almost no performance loss, but its inference is not much faster than before. Pruning alone does not give us a faster model. Hopefully, this helps you choose a compression technique for wav2vec 2.0.

What’s next?

We showed you how to compress wav2vec 2.0 for faster inference in this post. Read our next post on inference to learn how to use distributed inference and make inference even more efficient with wav2vec 2.0! If you are interested in compressing a wav2vec 2.0 model using knowledge distillation, check out this notebook. If you are interested in quantizing wav2vec 2.0, check out this notebook.

References

[1] Baevski et al. (2020). wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations. https://arxiv.org/abs/2006.11477

[2] Panayotov et al. (2015). LibriSpeech: An ASR Corpus Based on Public Domain Audio Books. https://ieeexplore.ieee.org/document/7178964

[3] Hinton et al. (2015). Distilling the Knowledge in a Neural Network. https://arxiv.org/abs/1503.02531

[4] PyTorch quantization documentation. https://pytorch.org/docs/stable/quantization.html

[5] Devlin et al. (2018). BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding. https://arxiv.org/abs/1810.04805

[6] Brown et al. (2020). Language Models are Few-Shot Learners. https://arxiv.org/abs/2005.14165

[7] Howard and Ruder (2018). Universal Language Model Fine-tuning for Text Classification. https://arxiv.org/abs/1801.06146

[8] Strubell et al. (2019). Energy and Policy Considerations for Deep Learning in NLP. https://arxiv.org/abs/1906.02243

[9] Bender et al. (2021). On the Dangers of Stochastic Parrots: Can Language Models Be Too Big? https://faculty.washington.edu/ebender/papers/Stochastic_Parrots.pdf

[10] Sanh et al. (2019). DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter. https://arxiv.org/abs/1910.01108

[11] Zafrir et al. (2019). Q8BERT: Quantized 8Bit BERT. https://arxiv.org/abs/1910.06188
