Efficient training and usage of NLU and NLG models

Vilen Jumutc
Jun 21, 2019

Introduction

How many of us have tried to train a simple chatbot on a handful of ad-hoc expressions? At first glance it might look like a simple, even futile, exercise. But when it comes to language models with good transfer and generalization capabilities, we cannot rely on an existing out-of-the-box solution. One should carefully analyze the implications and requirements of training state-of-the-art (SOTA) language models, e.g. BERT [1], from scratch and of using them afterwards.

To better understand the “dark” matter of the latest SOTA language models, let's first dissect the very popular BERT architecture and the way the Google AI Language team trains it. Let's have a look at the Encoder block in Figure 1:

Figure 1: The Encoder block preceded by embedding layers

As we can see, several crucial parts enable a successful application of the Transformer architecture here:

  1. Input, segment and positional embeddings (encodings) provide the initial numerical representation of input sequences.
  2. Multi-head (self-)attention enables contextual pairing and representation of words (tokens) across sequences.
  3. Layer (token-wise) normalization keeps deep-learning training healthy and controlled.
  4. Residual connections help to retain what was already learned at the previous block.
  5. The feed-forward layer helps to learn non-linear hierarchical features.

Stacking N such blocks gives the Encoder-based Transformer architecture, which relies heavily on Multi-Head attention: every head computes Scaled Dot-Product attention [2], and all heads can be computed in parallel.
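To make the Scaled Dot-Product attention of a single head more concrete, here is a minimal sketch in plain Python. It follows the formula softmax(QK^T / sqrt(d_k))V from [2]; the function names and the toy three-token input are our own illustration and are not taken from any BERT implementation.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    shift = max(row)
    exps = [math.exp(x - shift) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(queries, keys, values):
    """Compute softmax(Q K^T / sqrt(d_k)) V for one head on nested lists."""
    d_k = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        # Similarity of this query with every key, scaled by sqrt(d_k).
        scores = [
            sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k)
            for k in keys
        ]
        w = softmax(scores)
        weights.append(w)
        # Weighted average of the value vectors.
        outputs.append([
            sum(wj * v[c] for wj, v in zip(w, values))
            for c in range(len(values[0]))
        ])
    return outputs, weights


# Toy self-attention: three tokens with a head dimension of 2 (hypothetical data).
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
context, attention = scaled_dot_product_attention(tokens, tokens, tokens)
```

In a real model the queries, keys and values are separate learned projections of the token representations, and each head has its own projections; the sketch skips that part on purpose.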

Figure 2: Inputs, masking and outputs of the BERT model

What sets the BERT approach apart is the way Google pretrains the Transformer encoders and fine-tunes them on specific GLUE problems [3]. In brief, to enable unsupervised learning of token representations from the existing mass of text corpora, the Google AI Language team masks independently sampled input tokens, as shown in Figure 2, and jointly trains the model towards two optimization objectives: the Masked Language Model (MLM) and Next Sentence Prediction (NSP). The MLM objective asks the model to predict not the next word of a sequence but randomly chosen words from within the sequence. The NSP objective asks the model to predict whether the second sentence follows the first one in the corpus.
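The independently sampled masking can be sketched in a few lines. This is a simplified illustration, not the original preprocessing code: the default masking probability of 15% comes from the BERT paper [1], while the function name and return format are our own assumptions. The paper additionally replaces some of the selected tokens with random tokens or leaves them unchanged, which we omit here.

```python
import random

MASK = "[MASK]"


def mask_tokens(tokens, mask_prob=0.15, rng=None):
    """Mask every token independently with probability mask_prob.

    Returns the masked sequence and a dict mapping each masked position
    to its original token (the targets of the MLM objective).
    """
    rng = rng or random.Random()
    masked, targets = [], {}
    for position, token in enumerate(tokens):
        if rng.random() < mask_prob:
            masked.append(MASK)
            targets[position] = token
        else:
            masked.append(token)
    return masked, targets


# Hypothetical example sentence; a fixed seed keeps the run reproducible.
sentence = "the cat sat on the mat".split()
masked_sentence, mlm_targets = mask_tokens(sentence, rng=random.Random(0))
```

The model then only has to predict the tokens stored in mlm_targets, given the masked sequence as input.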

Training BERT from scratch

Our main goal in the company regarding BERT was to pretrain French and Dutch language models and to use the stacked encoder blocks as an embedding layer for chatbots and intent classification models. To start with, we carefully preprocessed the available data sources: the Wikipedia corpora for Dutch and French and the additional Dutch SoNaR corpus [4], which comprises 500 million words.

Preprocessing

The preprocessing stage was divided into several steps and was tailored towards the requirements of the BERT model, e.g. sentence tokenization. Let’s list our preprocessing efforts in detail:

  1. Sentence-level tokenization by the most common punctuation stop symbols: '.', '!', '?'
  2. Word-level tokenization by whitespace.
  3. Token cleaning by removing invalid characters and accents.
  4. Token normalization to lower-case.
  5. Token filtering by a minimum occurrence threshold (in the aggregated corpus) of 50 occurrences.
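The five steps above can be sketched as follows. This is a minimal illustration under several assumptions of ours, not the original pipeline: we treat everything outside a small whitelist of ASCII characters as "invalid", and we replace rare tokens with an [UNK] placeholder instead of dropping them. All function names are hypothetical.

```python
import re
import unicodedata
from collections import Counter

# Split after '.', '!' or '?' when followed by whitespace.
SENTENCE_END = re.compile(r"(?<=[.!?])\s+")


def strip_accents(text):
    """Decompose characters and drop the combining accent marks."""
    decomposed = unicodedata.normalize("NFKD", text)
    return "".join(ch for ch in decomposed if not unicodedata.combining(ch))


def clean_token(token):
    """Remove accents, lower-case, and drop characters outside the whitelist."""
    token = strip_accents(token).lower()
    return re.sub(r"[^a-z0-9.,!?'-]", "", token)


def preprocess(text):
    """Return a list of sentences, each a list of cleaned tokens."""
    sentences = []
    for sentence in SENTENCE_END.split(text.strip()):
        tokens = [clean_token(t) for t in sentence.split()]
        tokens = [t for t in tokens if t]
        if tokens:
            sentences.append(tokens)
    return sentences


def filter_rare(sentences, min_count=50, unk="[UNK]"):
    """Replace tokens seen fewer than min_count times in the whole corpus."""
    counts = Counter(t for s in sentences for t in s)
    return [[t if counts[t] >= min_count else unk for t in s] for s in sentences]
```

Note that whitespace tokenization leaves punctuation attached to the neighbouring word; whether to split it off is a design choice the list above does not settle.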

All of the aforementioned preprocessing steps contribute to the embedding layers of the model, e.g.:

  1. Sentence-level tokenization makes segment embeddings possible, which are essential for the NSP optimization objective and necessary for Natural Language Inference (NLI) type problems.
  2. Word-level tokenization yields the tokens for a word-level, word2vec-style input embedding.

Lessons learned

While we provisioned training on the available TPU quota, all of our preprocessing was executed on a VM with limited memory and CPU resources (64 GB of RAM and 32 vCPUs). To make use of all resources we had to carefully rethink our distribution and preprocessing strategies.

Here is the list of lessons learned while setting up efficient preprocessing:

  1. To estimate the dataset size (e.g. the number of rows), use Linux utilities like awk and wc through the subprocess Python package.
  2. Distribute your preprocessing across the available cores with the multiprocessing Python module.
  3. Try not to pass any preprocessed data between the master and child processes. Use the filesystem instead.
  4. Beware of memory issues when loading files into memory. Use the bare minimum of built-in Python tools, like enumerate(f), to iterate over a file with the least possible impact.
  5. Beware of forking child processes (with subprocess or multiprocessing) when memory usage is close to the limit: the Unix os.fork() call requires enough memory to be available for the copy-on-write mapping (mmap) of the parent process.
  6. Try to de-allocate heavy objects and variables when possible with locals().pop('data') or del data, and then force the Garbage Collector (GC) to run with gc.collect().
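Lessons 1 to 4 can be combined into one small sketch. It is an illustration rather than our production code: the function names, the shard file naming and the placeholder lower-casing step are all hypothetical, and it assumes a Linux machine where wc is available.

```python
import multiprocessing as mp
import os
import subprocess


def count_lines(path):
    """Count lines with `wc -l`, so the file never enters Python memory.

    `wc -l` counts newline characters, so a last line without a trailing
    newline is not counted.
    """
    result = subprocess.run(
        ["wc", "-l", path], check=True, capture_output=True, text=True
    )
    return int(result.stdout.split()[0])


def process_shard(job):
    """Worker: stream lines [start, stop) of src and write them to dst."""
    src, dst, start, stop = job
    with open(src, encoding="utf-8") as fin, open(dst, "w", encoding="utf-8") as fout:
        # enumerate(fin) reads lazily, one line at a time.
        for index, line in enumerate(fin):
            if index < start:
                continue
            if index >= stop:
                break
            fout.write(line.lower())  # placeholder for the real preprocessing
    return dst  # only a short path travels back to the parent process


def preprocess_in_parallel(src, out_dir, workers=4):
    """Split src into line ranges and process them in a pool of workers."""
    total = count_lines(src)
    step = max(1, -(-total // workers))  # ceiling division
    jobs = [
        (src, os.path.join(out_dir, f"shard_{k}.txt"), start, min(start + step, total))
        for k, start in enumerate(range(0, total, step))
    ]
    with mp.Pool(workers) as pool:
        return pool.map(process_shard, jobs)
```

Every worker skips ahead from the beginning of the file, which wastes some I/O but keeps the sketch simple; the important part is that only file paths and line offsets cross the process boundary, never the data itself.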

Distributed training with TPUs

While adjusting our code to TensorFlow’s TPUDistributionStrategy we encountered many unforeseen difficulties, which we briefly describe below. Many of the problems with TPUs were linked to the preemptible TPU quota we had to work with.

Figure 3: TPU rack in the data center

Here is the list of problems we met while distributing our BERT model training, together with our solutions:

  1. Problem: a TPU can be preempted at any time. Solution: handle tf.errors.UnavailableError and tf.errors.InvalidArgumentError by wrapping your code with “CRUD” operations on the TPU via the gcloud compute tpus command, and add model warm-start and restore capabilities to the training pipeline.
  2. Problem: using the Keras framework and existing BERT implementations, like [5], with TPUDistributionStrategy is problematic due to failing imports. Solution: monkey-patch the Keras imports to point to TensorFlow’s Keras contrib packages, which are aware of TPUDistributionStrategy.
  3. Problem: tracking all experiments across all available TPUs. Solution: use an existing framework, like Sacred [6], to track the experimental setup, parameters, logs and results (we used MongoDB as a back-end and a dashboard available for Sacred [6] to nicely visualize all the results).
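The preemption handling from the first item boils down to a retry loop around the training call. The sketch below shows only that control flow and is deliberately free of TensorFlow and gcloud calls: train_fn, recreate_tpu_fn and the parameter names are hypothetical placeholders. In a pipeline like ours, retryable_errors would be (tf.errors.UnavailableError, tf.errors.InvalidArgumentError), recreate_tpu_fn would shell out to gcloud compute tpus, and train_fn would restore the latest checkpoint before continuing.

```python
import time


def train_with_restarts(train_fn, recreate_tpu_fn, retryable_errors,
                        max_restarts=5, wait_seconds=0.0):
    """Run train_fn and restart it after a retryable failure.

    train_fn         -- callable that restores the latest checkpoint and trains
    recreate_tpu_fn  -- callable that deletes and re-creates the TPU node
    retryable_errors -- tuple of exception classes that signal a preemption
    """
    restarts = 0
    while True:
        try:
            return train_fn()
        except retryable_errors:
            if restarts >= max_restarts:
                raise  # give up and surface the last error
            restarts += 1
            recreate_tpu_fn()
            time.sleep(wait_seconds)  # give the new node time to come up
```

Because the loop itself knows nothing about checkpoints, the warm-start and restore logic has to live inside train_fn; otherwise every restart would begin from scratch.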

Training adjustments and improvements

As we used a non-native implementation of BERT and completely new corpora/datasets, we had to change our learning rate schedule as well. This step was crucial for proper convergence and stability. A major difference from previous results on BERT was the use of the keras.callbacks.ReduceLROnPlateau strategy, which adjusts the learning rate based on the validation loss (in comparison to the polynomial decay in the original paper). We set the patience period to 10 epochs, after which the learning rate was multiplied by a factor of 0.2. Additionally, we used keras.callbacks.EarlyStopping to stop training early when the validation loss did not improve for 20 consecutive epochs. Another big difference was the input feed to the training: we set batch_size to 8 samples, the maximum that still fit into the memory of a TPUv2.

Following our strict TPU computational constraints, we reduced the original BERT model size to 6 stacked encoders with 6-head attention layers and an embedding dimension of 600 for each attention head. The dimension of the feed-forward layers was limited to 1024 and the input sequence length to 256 tokens.
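To see how the two callbacks interact, the following plain-Python sketch replays a list of validation losses through the same rules: cut the learning rate by a factor of 0.2 after 10 epochs without improvement, and stop after 20 epochs without improvement. It is a simplified re-implementation of the idea for illustration, not the Keras code (it ignores options such as cooldown or a minimum delta), and the function and variable names are our own.

```python
def simulate_schedule(val_losses, lr=1e-4, factor=0.2,
                      lr_patience=10, stop_patience=20):
    """Replay validation losses; return (learning rate per epoch, stop epoch).

    The stop epoch is None when early stopping never triggers.
    """
    best = float("inf")
    since_best = 0    # epochs without improvement, for early stopping
    since_change = 0  # epochs since the best loss or the last LR cut
    history = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, since_best, since_change = loss, 0, 0
        else:
            since_best += 1
            since_change += 1
        if since_change >= lr_patience:
            lr *= factor       # plateau detected: shrink the learning rate
            since_change = 0
        history.append(lr)
        if since_best >= stop_patience:
            return history, epoch
    return history, None


# Hypothetical run: the loss improves once and then stays flat.
rates, stopped_at = simulate_schedule([1.0] + [1.0] * 25)
```

With a loss that never improves after the first epoch, the learning rate is cut twice before training stops, which shows why the early-stopping patience has to be larger than the plateau patience for the reduced learning rate to get a chance.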

Training results and convergence

In the figures below (Figures 4 and 5) one can see the convergence of the validation and train loss for a learning rate of 1e-4. For visualization reasons we excluded the first epoch from the train loss figure, as it introduces a large spike and makes the evolution of the metric harder to see.

Figure 4: Convergence of the validation loss
Figure 5: Convergence of the train loss

As we can see, the losses converge steadily, although the validation loss slows down at later epochs. We also observed that at some point we could not gain any significant improvement in the validation loss, no matter how fast we decreased the learning rate. This might indicate overfitting to the particular optimization objective. We tried different initial learning rates, and the best one, along with the other hyperparameters, was chosen for the final dry-run.

Natural Language Generation experiments with BERT

As generalization from small corpora is such a ubiquitous problem in the chatbot industry, we focused our efforts on augmenting datasets with artificially generated samples. To approach this problem one can take a closer look at BERT and at a capability that is a natural extension of the Masked Language Model. Here we mask tokens not randomly but consecutively within one sentence, such that the ending is obfuscated (as shown in Figure 6 below).

Figure 6: Proposed masking scheme for Natural Language Generation

This methodology for generating sequences is tailored towards the reconstruction of input sequences under the MLM objective. We use only the latter as our final training objective, omitting the irrelevant NSP term. At the generation step we append [MASK] tokens to randomly cropped existing expressions from the chatbot corpora and thereby augment our training examples.

Of course, some additional checks are performed on the validity of the generated expressions. These can be split into an automated approach using existing NLP tools, e.g. grammar parsing with NLTK [7], and user-based validation.
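Both sides of this scheme, masking the ending during training and appending [MASK] tokens at generation time, can be sketched as below. The helper names, the choice to crop from the end, and the number of masks are assumptions made for illustration; the model call that fills in the masks is left out.

```python
import random

MASK = "[MASK]"


def mask_ending(tokens, n_masked):
    """Training view: hide the last n_masked tokens of one sentence.

    Returns the masked sequence and the hidden tokens (the MLM targets).
    At least one token is always left visible.
    """
    n_masked = max(0, min(n_masked, len(tokens) - 1))
    cut = len(tokens) - n_masked
    return tokens[:cut] + [MASK] * n_masked, tokens[cut:]


def crop_and_mask(tokens, n_masks, rng=None, min_keep=1):
    """Generation view: keep a random-length prefix and append n_masks masks."""
    rng = rng or random.Random()
    keep = rng.randint(min_keep, max(min_keep, len(tokens) - 1))
    return tokens[:keep] + [MASK] * n_masks


# Hypothetical chatbot expression.
expression = "i want to book a flight".split()
train_input, train_targets = mask_ending(expression, 2)
prompt = crop_and_mask(expression, 3, rng=random.Random(1))
```

The trained model would then predict a token for every [MASK] in prompt, giving a new variant of the original expression that still has to pass the validity checks.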

References

[1] https://arxiv.org/abs/1810.04805

[2] https://papers.nips.cc/paper/7181-attention-is-all-you-need.pdf

[3] https://gluebenchmark.com

[4] https://ivdnt.org/taalmaterialen/2026-tstc-sonar-corpus

[5] https://github.com/CyberZHG/keras-bert

[6] https://github.com/IDSIA/sacred

[7] http://www.nltk.org

Postscript

This is my first Medium post, so feel free to point out any inconsistencies, errors and mistakes. Everyone is also welcome to comment and ask questions about my experience and explorations on the presented topic!
