Machine translation from scratch with MXNet and R

Jeremie Desgagne Bouchard
Published in Apache MXNet
8 min read · Mar 13, 2019

In this post, we’ll see how to develop a complete machine translation system from scratch using the MXNet R package. We’ll achieve a BLEU score over 28 on an English to French task with a single model trained for a day on a single GPU, without relying on any outside resources such as pre-trained word vectors or tokenizers. Full code to reproduce the model can be found in the translatR repo.

Acquiring data

The first step in building a translation model is to gather and prepare the data. Thankfully, the WMT has made a large parallel corpus available.

In this demo, the Europarl v7 and Common Crawl corpora are used, providing over 4M sentence pairs.

library(readr)  # provides read_lines()
dir.create("./data", showWarnings = FALSE)  # make sure the target folder exists
download.file(url = "http://www.statmt.org/europarl/v7/fr-en.tgz", destfile = "./data/europarl_fr-en.tgz")
untar(tarfile = "./data/europarl_fr-en.tgz", exdir = "./data/")
euro_en <- read_lines("data/europarl-v7.fr-en.en")
euro_fr <- read_lines("data/europarl-v7.fr-en.fr")

The import above yields two large character vectors of equal length, one for English and one for French, where each element is a sentence from Europarl. The same steps are repeated for Common Crawl and the results are concatenated.

Preprocessing

The model will be trained on sequences of words. Since training is performed by feeding the model arrays of fixed size, the goal of the preprocessing is to turn each raw sentence into a fixed-length vector of word indices.

This involves a few steps. The first is tokenization, by which a string containing a whole sequence is split into a vector of tokens.

To stay lightweight and inject as little prior knowledge of the language as possible, sequences are simply split on whitespace. So that the model can handle punctuation, a quick hack inserts spaces around each punctuation character. Finally, the placeholder tokens <BOS> and <EOS> are added to give the model beginning and ending cues.

source <- "I'd love to learn French!"
source <- gsub("([[:punct:]])", " \\1 ", source)
source <- paste("<BOS>", source, "<EOS>")
> strsplit(source, "\\s+")
[[1]]
[1] "<BOS>" "I" "'" "d" "love" "to" "learn" "French" "!" "<EOS>"
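The same steps can be wrapped into a function that handles a whole vector of sentences at once. This is a minimal sketch in base R: tokenize is a hypothetical helper name, and the trimws call is an addition that only removes the stray space left after trailing punctuation.

```r
# Hypothetical helper: tokenize a character vector of sentences.
tokenize <- function(sentences) {
  # Surround every punctuation character with spaces so it becomes its own token.
  spaced <- gsub("([[:punct:]])", " \\1 ", sentences)
  # Add the markers after the punctuation step, so that the angle
  # brackets of <BOS> and <EOS> are not split apart.
  marked <- paste("<BOS>", trimws(spaced), "<EOS>")
  # Split on runs of whitespace: one token vector per sentence.
  strsplit(marked, "\\s+")
}

tokens <- tokenize(c("I'd love to learn French!", "Hello, world."))
seq_length <- lengths(tokens)  # number of tokens in each sentence
```

The per-sentence lengths are what a table of sequence ids and word positions would be built from.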

The above vector of tokens is then rearranged into a data.table along with a sequence and word id:

source_dt <- data.table(word = unlist(source_word_vec_list),
                        seq_id = rep(1:length(source_seq_length),
                                     times = source_seq_length),
                        seq_word_id = seq_word_id_source)
> source_dt
      word seq_id seq_word_id
 1:  <BOS>      1           1
 2:      I      1           2
 3:      '      1           3
 4:      d      1           4
 5:   love      1           5
 6:     to      1           6
 7:  learn      1           7
 8: French      1           8
 9:      !      1           9
10:  <EOS>      1          10

The data.table format is efficient for building a dictionary that maps each token to an index. Tokens are counted and rare ones are ignored to keep the vocabulary size within the 20k to 50k range.

source_word_count <- source_dt[, .N, by = word]
source_dic <- source_word_count[N >= word_count_min][order(-N)]

Two other special tokens are also introduced: <PAD> and <UNKNOWN>. The first is used to fill sequences shorter than the data matrix and the second is used as a default for tokens not present in the dictionary.
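To make the role of the special tokens concrete, here is a small base R sketch of a dictionary with an <UNKNOWN> fallback. The id layout (0 for <PAD>, 1 for <UNKNOWN>, then the kept tokens), the toy corpus and the to_ids helper are all assumptions made for illustration; the repo may assign ids differently.

```r
# Toy corpus, already tokenized (assumed input).
tokens <- list(c("<BOS>", "the", "cat", "<EOS>"),
               c("<BOS>", "the", "dog", "<EOS>"))
word_count_min <- 2  # illustrative threshold

# Count tokens and keep the frequent ones, most frequent first.
counts <- sort(table(unlist(tokens)), decreasing = TRUE)
kept <- names(counts)[as.vector(counts) >= word_count_min]

# Assumed id layout: 0 = <PAD>, 1 = <UNKNOWN>, then the kept tokens.
dic <- c("<PAD>" = 0L, "<UNKNOWN>" = 1L,
         setNames(seq_along(kept) + 1L, kept))

# Look up ids, falling back to <UNKNOWN> for tokens outside the dictionary.
to_ids <- function(x, dic) {
  ids <- unname(dic[x])
  ids[is.na(ids)] <- dic[["<UNKNOWN>"]]
  ids
}

ids <- to_ids(c("<BOS>", "the", "bird", "<EOS>"), dic)
```

Here "cat" and "dog" fall below the threshold and "bird" was never seen, so all three would map to the <UNKNOWN> id.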

Once the dictionary is built, the remaining step is to join it to the above data.table and reshape the result into a table of size [max sequence length, number of sequences] using dcast, forcing a common sequence length:

source_dt <- source_dic[source_dt, on = "word"][order(seq_id, seq_word_id)]
source_dt <- dcast(data = source_dt, seq_word_id ~ seq_id, value.var = "word_id", fill = 0)
source <- as.matrix(source_dt[, c("seq_word_id") := NULL])
> source
         1    2
 [1,]    5    5
 [2,]   22   22
 [3,]   21   21
 [4,]  550  550
 [5,] 1258 1258
 [6,]    8    8
 [7,] 1161 1161
 [8,]  424  424
 [9,]   86   10
[10,]    6 1645
[11,]    0   86
[12,]    0    6
The raw text has now been transformed into a matrix of dimensions [Max Seq Length, Number of Sequences], ready to feed a translation model. The first column matches the indices shown in the initial preprocessing figure.
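The padding logic can also be written without data.table. The sketch below is a base R alternative, not the code used in the repo: pad_sequences is a hypothetical helper, and truncating sequences longer than seq_len is an assumption about how overly long sentences would be handled.

```r
# Hypothetical helper: pad id sequences into a [seq_len, n_seq] matrix.
pad_sequences <- function(id_list, seq_len, pad_id = 0L) {
  vapply(id_list, function(ids) {
    ids <- head(ids, seq_len)                    # truncate long sequences
    c(ids, rep(pad_id, seq_len - length(ids)))   # right-pad short ones
  }, FUN.VALUE = integer(seq_len))
}

ids <- list(c(5L, 22L, 86L, 6L),
            c(5L, 22L, 10L, 1645L, 86L, 6L))
m <- pad_sequences(ids, seq_len = 6)
dim(m)  # rows are word positions, columns are sequences
```

Each column holds one sequence, with zeros filling the positions past its end.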

Architecture

A basic sequence-to-sequence model can be represented as:

In this design, the encoder must provide a single feature vector that carries the information of the entire sequence, which led to the joke that an unlimited number of clowns could be pulled out of that single vector. The intuition that this would take supernatural capabilities is valid: more effective designs avoid the bottleneck imposed by the fixed-length encoding by relying on a clever trick, attention.

No more clowns! With an attention mechanism, the data feeding the decoder is now the entire encoded sequence, solving the information bottleneck of the previous architecture.

Attention can be implemented in many flavors. To provide flexibility, my implementation uses the abstract Query-Key-Value notation described in the Attention Is All You Need paper. A benefit of this approach is the resulting modularity of the translation system. The Encoder provides both the Value and Key matrices, the latter being a transformation of the former, while the Decoder provides the Query, which is the feature vector of the token being decoded.

The Attention module will return a weighted average of the Value matrix (the attention vector), which will be used to enhance the vector of features during decoding.

Multiple variations on the Query-Key-Value paradigm are possible. Bilinear and MLP have been implemented in addition to the Dot attention illustrated below:

attn <- attn_dot(value = value, query_key_size = num_hidden, scale = TRUE)
init <- attn$init()
attend <- attn$attend
attention <- attend(query = query, key = init$key, value = init$value, attn_init = init)

The above is the actual MXNet graph for the dot attention, where the batch size is 128 (last dimension). For each token to be decoded, the query is a reprojection of the 512-length representation of that token. The value is the full encoding of the source sequence. It is itself reprojected to form a key, and the dot product of the key with the query gives the weights applied to the value matrix. The resulting 512-length vector is called the context vector, which is then appended to the original token encoding to calculate the score associated with each word of the target vocabulary.
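The arithmetic behind dot attention fits in a few lines of plain R. This is a numeric sketch for a single sentence and a single decoded token, not the MXNet symbol graph: the toy sizes, the random projection matrices W_k and W_q, and the scaling by the square root of the feature size are assumptions chosen for illustration.

```r
set.seed(1)
seq_len <- 4  # source tokens (toy size)
d <- 8        # feature size (toy size; the post uses 512)

value <- matrix(rnorm(seq_len * d), nrow = seq_len)  # encoder output, one row per source token
W_k <- matrix(rnorm(d * d), nrow = d)                 # hypothetical key projection
W_q <- matrix(rnorm(d * d), nrow = d)                 # hypothetical query projection
token <- rnorm(d)                                     # features of the token being decoded

key <- value %*% W_k                # [seq_len, d], reprojection of the value
query <- as.vector(token %*% W_q)   # [d], reprojection of the token

# Scaled dot product: one score per source position, then a softmax.
score <- as.vector(key %*% query) / sqrt(d)
weight <- exp(score - max(score))
weight <- weight / sum(weight)

# Context vector: weighted average of the rows of the value matrix.
context <- as.vector(t(value) %*% weight)

# The context is appended to the token encoding before scoring the vocabulary.
decoder_features <- c(token, context)
```

The weights sum to one, so the context vector always stays within the range spanned by the encoded source tokens.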

The final component of the model is the softmax loss function. It normalizes the above scores into a probability distribution and uses the cross-entropy loss function to derive the head gradient to propagate.
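For a single decoded token, the softmax and cross-entropy computation can be sketched as follows. The scores and the label are toy values, and the gradient shown is the standard result for softmax combined with cross-entropy (predicted probabilities minus the one-hot label), which is the head gradient a softmax output layer would propagate.

```r
softmax <- function(z) {
  e <- exp(z - max(z))  # subtract the max for numerical stability
  e / sum(e)
}

score <- c(2.0, 0.5, -1.0, 0.1)  # toy scores over a 4-word vocabulary
label <- 1                        # index of the true next word (1-based)

prob <- softmax(score)
loss <- -log(prob[label])         # cross-entropy for this token

# Gradient of the loss with respect to the scores: prob minus one-hot label.
grad <- prob
grad[label] <- grad[label] - 1
```

The gradient pushes the score of the true word up and the scores of all other words down, in proportion to the probability they were given.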

Training

Thanks to the modular encoder-attention-decoder design, the complete model can be built in a straightforward fashion. A remaining subtlety is that during training, the decoder takes advantage of a teacher: at each step, the true previous token is fed rather than the predicted one. Such information is not available when performing inference. A second decoder is therefore built, which uses the most likely word at inference (argmax over the predictions) rather than the true label.
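The difference between the two decoders can be illustrated on a toy example. In this sketch, next_scores is a hypothetical stand-in for one decoder step that returns a named score vector over the vocabulary, and toy_scores is a fake "model" that always prefers the token following the previous one in the target. Neither is part of MXNet or of the repo.

```r
target <- c("<BOS>", "je", "parle", "un", "peu", "<EOS>")

# Teacher forcing: the decoder input is the true target shifted by one step.
decoder_input <- head(target, -1)  # what the decoder sees at each step
decoder_label <- tail(target, -1)  # what it must predict at each step

# Greedy inference: feed back the model's own best guess instead.
greedy_decode <- function(next_scores, max_len = 20) {
  out <- character(0)
  prev <- "<BOS>"
  while (length(out) < max_len) {
    scores <- next_scores(prev)
    prev <- names(scores)[which.max(scores)]  # argmax over the vocabulary
    if (prev == "<EOS>") break
    out <- c(out, prev)
  }
  out
}

# Toy "model": scores 1 for the token that follows `prev` in the target, 0 otherwise.
toy_scores <- function(prev) {
  vocab <- unique(target)
  setNames(as.numeric(vocab == target[match(prev, target) + 1]), vocab)
}

decoded <- greedy_decode(toy_scores)
```

With a real model, an early mistake in the greedy loop is fed back into the next step, which is why training with a teacher and decoding with argmax can behave differently.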

The hyper-parameters for training were kept fairly vanilla, using an Adam optimizer with a decreasing learning rate:

initializer <- mx.init.Xavier(rnd_type = "uniform", factor_type = "in", magnitude = 2.5)
lr_scheduler <- mx.lr_scheduler.FactorScheduler(step = 5000, factor_val = 0.9, stop_factor_lr = 5e-5)
optimizer <- mx.opt.create("adam", learning.rate = 5e-4, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, wd = 1e-8, clip_gradient = 1, rescale.grad = 1, lr_scheduler = lr_scheduler)
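To get a feel for what this schedule does, the sketch below computes the learning rate after a given number of updates. It assumes the scheduler multiplies the rate by factor_val every step updates and never goes below stop_factor_lr. lr_at is a hypothetical helper, not part of MXNet, and the exact boundary handling in the library may differ by an update.

```r
# Hypothetical helper: learning rate after `num_update` parameter updates,
# assuming a stepwise geometric decay with a lower bound.
lr_at <- function(num_update, base_lr = 5e-4, step = 5000,
                  factor_val = 0.9, stop_factor_lr = 5e-5) {
  n_decay <- floor(num_update / step)                # decays applied so far
  pmax(base_lr * factor_val^n_decay, stop_factor_lr) # never below the floor
}

lr <- lr_at(c(0, 5000, 50000, 105000, 110000))
```

Under these assumptions the floor of 5e-5 is reached after 22 decays, that is, around 110,000 updates.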

The model was then trained for 8 epochs, taking about a full day on a V100 GPU.

model <- mx.model.buckets(symbol = decode_teacher,
                          train.data = iter_train,
                          eval.data = iter_eval,
                          num.round = 12, ctx = ctx, verbose = TRUE,
                          metric = mx.metric.Perplexity,
                          optimizer = optimizer,
                          initializer = initializer,
                          batch.end.callback = batch.end.callback,
                          epoch.end.callback = epoch.end.callback)
mx.model.save(model = model, prefix = "models/en_fr_cnn_rnn_teacher", iteration = 8)
mx.symbol.save(symbol = decode_argmax, filename = "models/en_fr_cnn_rnn_argmax.json")

Perplexity is used as the evaluation metric to track the progress of the training.
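As a reminder of what the metric measures, perplexity is the exponential of the average negative log-likelihood of the true tokens. The sketch below uses toy probabilities and assumes padding positions have already been dropped; mx.metric.Perplexity may handle padding and numerical clipping differently.

```r
# Probability the model assigned to each true target token (toy values).
prob_true <- c(0.5, 0.25, 0.125, 0.5)

# Perplexity: exp of the mean negative log-likelihood.
perplexity <- exp(-mean(log(prob_true)))

# Sanity check: a uniform guess over V words has a perplexity of V.
V <- 1000
perplexity_uniform <- exp(-mean(log(rep(1 / V, 10))))
```

A perplexity of 10 can thus be read as the model hesitating, on average, between about 10 equally likely words.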

Inference

To have a comparable assessment of the translation quality, the model can be benchmarked against the official WMT test set. To do so, the sacreBLEU library comes in handy:

sacrebleu --test-set wmt15 --language-pair en-fr --echo src > wmt15-en-fr.src

When performing inference on a new dataset, it’s crucial to apply the same preprocessing as for the training data. Luckily, very few transformations were applied in our scenario, making this step easily replicable on the wmt15-en-fr.src data. Obviously, the same dictionary must be applied as well, so the one developed for training will be used during the preprocessing step rather than building a new one on the fly.

The inference model is obtained by combining the argmax structure with the weights learned during training with the teacher.

model <- mx.model.load(prefix = "models/model_wmt15_en_fr_cnn_rnn_teacher_v2", iteration = 12)
sym_infer <- mx.symbol.load(file.name = "models/model_wmt15_en_fr_cnn_rnn_argmax_v2.json")
model_infer <- list(symbol = sym_infer, arg.params = model$arg.params, aux.params = model$aux.params)
model_infer <- structure(model_infer, class="MXFeedForwardModel")

The inference can then be applied on the test data, stored as a text file, ready to be evaluated by sacreBLEU:

cat wmt15_en_fr_cnn_rnn.txt | sacrebleu -t wmt15 -l en-fr

The resulting performance summary should look similar to:

BLEU+case.mixed+lang.en-fr+numrefs.1+smooth.exp+test.wmt15+tok.13a+version.1.2.12 = 28.2 61.0/36.2/23.8/16.1 (BP = 0.930 ratio = 0.933 hyp_len = 26090 ref_len = 27975)

This indicates a BLEU score of 28.2.
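The numbers in that summary line fit together. As a worked check, the sketch below recomputes the brevity penalty from the two lengths, then the BLEU score as the brevity penalty times the geometric mean of the four n-gram precisions. Since the precisions are taken from the rounded report, the result only matches up to rounding.

```r
precisions <- c(61.0, 36.2, 23.8, 16.1)  # 1- to 4-gram precisions, in percent
hyp_len <- 26090                         # total length of the translations
ref_len <- 27975                         # total length of the references

# Brevity penalty: penalizes hypotheses shorter than the reference.
bp <- if (hyp_len >= ref_len) 1 else exp(1 - ref_len / hyp_len)

# BLEU: brevity penalty times the geometric mean of the n-gram precisions.
bleu <- bp * exp(mean(log(precisions)))

round(c(BP = bp, ratio = hyp_len / ref_len, BLEU = bleu), 3)
```

This gives a brevity penalty of about 0.930 and a score of about 28.2, in line with the sacreBLEU output. It also shows that the translations are roughly 7% shorter than the references, which costs about two BLEU points.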

Test sentences can also be submitted for translation, validating the soundness of the model:

> infer_helper(infer_seq = "I'd love to learn French!",
               model = model_infer,
               source_dic = source_dic,
               target_dic = target_dic,
               seq_len = seq_len)
[1] "J'aimerais apprendre le français!"

Improving

To reach state-of-the-art performance, some additional tricks can be considered.

  • Tokenization: best performing systems typically use more sophisticated tokenization schemes, notably BPE which creates sub-word splits.
  • Positional embedding: in addition to the token ids that are used as input data, features that represent the position of the token within the sentence can be added. These can either be a single position indicator (absolute or relative) or a more complex collection of sin/cos waves as used in the transformer model.
  • Model ensembling: average the predictions of a few models.
  • Beam search: rather than using the single best translated token, top N candidates and their associated next best step are generated and the token associated with the max likelihood path is kept. This partially circumvents limitations of the greedy argmax decoding.

Many of those features are integrated in the comprehensive Sockeye library built on top of MXNet.
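As an illustration of the sin/cos positional features mentioned in the list, here is a base R sketch of the sinusoidal encoding from the transformer paper. positional_encoding is a hypothetical helper written for this post's context, not a function from MXNet or Sockeye.

```r
# Hypothetical helper: sinusoidal position features, one row per position.
positional_encoding <- function(seq_len, d_model) {
  stopifnot(d_model %% 2 == 0)
  pos <- 0:(seq_len - 1)    # token positions, starting at 0
  i <- 0:(d_model / 2 - 1)  # index of each sin/cos pair
  # One wavelength per pair, growing geometrically with i.
  angle <- outer(pos, 1 / 10000^(2 * i / d_model))  # [seq_len, d_model / 2]
  pe <- matrix(0, nrow = seq_len, ncol = d_model)
  pe[, seq(1, d_model, by = 2)] <- sin(angle)  # first feature of each pair
  pe[, seq(2, d_model, by = 2)] <- cos(angle)  # second feature of each pair
  pe
}

pe <- positional_encoding(seq_len = 10, d_model = 8)
```

The resulting matrix would typically be added to the token embeddings, so that two identical words at different positions get different representations.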
