Fine-tuning the multilingual T5 model from Huggingface with Keras

Radi Cho
2 min read · Feb 18, 2023


Multilingual T5 (mT5) is the massively multilingual version of Google's T5 text-to-text transformer model. It is pre-trained on the mC4 corpus, covering 101 languages! However, unlike the original T5, mT5 was pre-trained only with an unsupervised objective, so it has to be fine-tuned before it can be used on any downstream task. In this tutorial, we will build a simple notebook that fine-tunes mT5 with Keras. The pre-processing steps shown here are applicable to many other models distributed through the Huggingface Hub. Our main goal is to show you a minimalistic approach for training text generation architectures from Huggingface with TensorFlow and Keras as the backend.

Check out the original T5 paper: https://arxiv.org/abs/1910.10683.

Setup

We will be using TensorFlow, datasets, and transformers, so make sure you have them installed in your environment.
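If they are missing, a notebook cell along these lines installs them (sentencepiece is listed because the mT5 tokenizer typically relies on it; exact versions are up to you):

```python
# Assumes a notebook environment; run once to install the dependencies.
!pip install transformers datasets tensorflow sentencepiece
```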

Download the TensorFlow-flavoured mT5 model and its respective tokenizer.
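For example (the choice of the small checkpoint google/mt5-small is an assumption to keep the demo light):

```python
from transformers import AutoTokenizer, TFMT5ForConditionalGeneration

model_checkpoint = "google/mt5-small"  # assumed checkpoint; any mT5 size works

# Load the tokenizer and the TensorFlow version of the model from the Hub.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = TFMT5ForConditionalGeneration.from_pretrained(model_checkpoint)
```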

Data preprocessing

As a use-case example, let’s use the Multilingual spell correction competition from Kaggle. Its dataset structure is as follows:

- Id: unique id for each sentence pair
- Language: the language of the sentence
- Text: the noisy text
- Expected: the sentence to be reconstructed

We need to tokenize the Text samples as inputs and the Expected values as labels. To do that, let's create a helper function that passes dataset values through the initialized tokenizer and pads and truncates the outputs.
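A sketch of such a helper, assuming the Text and Expected column names from the Kaggle dataset and a maximum length of 128 tokens:

```python
max_length = 128  # assumption: maximum length for both inputs and labels

def preprocess_function(examples):
    # Tokenize the noisy sentences as model inputs.
    model_inputs = tokenizer(
        examples["Text"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    # Tokenize the corrected sentences and use their token ids as labels.
    labels = tokenizer(
        examples["Expected"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs
```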

This function can now be applied to all samples in the initial dataset.
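For instance, assuming the competition data has been downloaded locally as train.csv (the file name is an assumption), datasets can load and map it in a couple of lines:

```python
from datasets import load_dataset

# Assumed local path to the competition data.
dataset = load_dataset("csv", data_files={"train": "train.csv"})

# Tokenize every sample; batched=True processes many rows per call.
tokenized_dataset = dataset.map(preprocess_function, batched=True)
```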

Next, we need an instance of tf.data.Dataset so that we can use it easily with the Keras API. We can call model.prepare_tf_dataset to get a TensorFlow dataset compatible with the model's specific input requirements. In the case of mT5, model_inputs["input_ids"] and model_inputs["labels"] are essential (we already prepared them in preprocess_function).

We need to specify a collator for the dataset transformation. Data collators are objects that form a batch from a list of dataset elements. The DataCollatorForSeq2Seq used here pads the sequences automatically and groups them into batches. The padding we already did in preprocess_function is simply a manual way of specifying a custom pre-processing strategy. For example, in Huggingface's translation example, padding tokens in the labels are replaced with -100 so that no loss is computed over them. For more details, check the example here.
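Putting the collator and prepare_tf_dataset together, a possible sketch (the batch size is an assumption) is:

```python
from transformers import DataCollatorForSeq2Seq

# Collate tokenized examples into padded TensorFlow batches.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")

tf_train_dataset = model.prepare_tf_dataset(
    tokenized_dataset["train"],
    batch_size=8,  # assumption: adjust to your hardware
    shuffle=True,
    collate_fn=data_collator,
)
```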

Training

We can now use tf_train_dataset, which is compatible with the TFMT5ForConditionalGeneration model and the Keras API, to train the model with a plain compile and fit call.
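A minimal sketch (the optimizer choice, learning rate, and number of epochs are assumptions):

```python
import tensorflow as tf

# No loss is passed to compile(): the model falls back to its internal loss.
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)  # assumed learning rate
model.compile(optimizer=optimizer)

model.fit(tf_train_dataset, epochs=3)  # assumed number of epochs
```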

Note that the model's internal loss computation will be used as the loss for model.fit, as this is a common way to train TensorFlow models from the transformers library.

Full demo and inference examples can be found here.

You can also inherit from TFMT5ForConditionalGeneration in your own custom model and override the train_step function to optimize an objective specified with the GradientTape API. For more on that topic, refer to this notebook.
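As a rough sketch of that pattern (this is not the code from the linked notebook; the class name is made up, and for simplicity it just reuses the model's own loss rather than a custom objective), such a subclass could look like this:

```python
import tensorflow as tf
from transformers import TFMT5ForConditionalGeneration

class CustomMT5(TFMT5ForConditionalGeneration):  # hypothetical subclass name
    def train_step(self, data):
        # Batches from prepare_tf_dataset may arrive as a dict or as an
        # (inputs, labels) pair; recombine them into a single input dict.
        x, y, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        inputs = dict(x)
        if y is not None:
            inputs["labels"] = y
        with tf.GradientTape() as tape:
            outputs = self(inputs, training=True)
            # A custom objective could be built from outputs.logits here;
            # this sketch simply reduces the model's standard loss.
            loss = tf.reduce_mean(outputs.loss)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"loss": loss}
```

A model created with CustomMT5.from_pretrained(...) can then be compiled and fitted exactly as before.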

Further reading

You can check out this article by David Dale on how to trim unused weights from mT5 in order to focus on a specific language.


Radi Cho

Google Developer Expert in Machine Learning. Forbes Bulgaria "30 under 30".