Fine-tuning the multilingual T5 model from Huggingface with Keras
Multilingual T5 (mT5) is the massively multilingual version of the T5 text-to-text transformer model by Google. It is pre-trained on the mC4 corpus, covering 101 languages! However, unlike the original T5, mT5 was pre-trained only on an unsupervised objective, so it has to be fine-tuned before it is usable on any downstream task. In this tutorial, we will build a simple notebook that fine-tunes mT5 with Keras. The showcased pre-processing procedures are applicable to many other models distributed through the Huggingface Hub. Our main goal is to show you a minimalistic approach for training text generation architectures from Huggingface with TensorFlow and Keras as the backend.
Setup
We will be using TensorFlow, datasets, and transformers, so make sure you have them installed in your environment.
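If you are starting from a fresh environment, the snippet below is a minimal sketch of the imports used throughout this tutorial; the exact package versions and the sentencepiece dependency for the mT5 tokenizer are assumptions on my part.

```python
# Assumed installation: pip install tensorflow transformers datasets sentencepiece
import tensorflow as tf
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    TFMT5ForConditionalGeneration,
)

print(tf.__version__)  # any recent TensorFlow 2.x release should work
```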
Data preprocessing
As a use-case example, let’s use the Multilingual spell correction competition from Kaggle. Each row of its dataset pairs a Text column, containing a (possibly misspelled) input sentence, with an Expected column containing the corrected version.
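As a hedged sketch, such a CSV can be loaded with the datasets library; the file name train.csv and the 10% validation split below are illustrative assumptions, not part of the competition setup.

```python
from datasets import load_dataset

# Load the competition CSV (file name assumed) and keep a small
# validation split so we can monitor the fine-tuning.
raw_datasets = load_dataset("csv", data_files={"train": "train.csv"})
raw_datasets = raw_datasets["train"].train_test_split(test_size=0.1, seed=42)
print(raw_datasets)  # DatasetDict with "train" and "test" splits
```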
We need to tokenize the Text samples as inputs and the Expected values as labels. To do that, let’s create a helper function to pass dataset values through the initialized tokenizer and pad and truncate the outputs.
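A sketch of such a helper, assuming the google/mt5-small checkpoint and an illustrative maximum length of 128 tokens (both are assumptions, not fixed by the competition):

```python
from transformers import AutoTokenizer

MODEL_NAME = "google/mt5-small"  # assumed checkpoint; any mT5 size works
MAX_LENGTH = 128                 # illustrative maximum sequence length

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def preprocess_function(examples):
    # Tokenize the (possibly misspelled) "Text" column as model inputs.
    model_inputs = tokenizer(
        examples["Text"],
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
    )
    # Tokenize the corrected "Expected" column as labels.
    labels = tokenizer(
        text_target=examples["Expected"],
        max_length=MAX_LENGTH,
        padding="max_length",
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Apply the helper to the splits loaded above.
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True)
```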
Next, we need an instance of tf.data.Dataset to easily use the data with the Keras API. We can call model.prepare_tf_dataset to get a TensorFlow dataset compatible with the model's specific input requirements. In the case of mT5, model_inputs["input_ids"] and model_inputs["labels"] are essential (we already prepared them in preprocess_function).
We also need to specify a collator for the dataset transformation. Data collators are objects that form a batch from a list of dataset elements. The DataCollatorForSeq2Seq used here can pad the sequences automatically and group them into batches. However, the padding we already did in preprocess_function is a manual way of specifying a custom pre-processing strategy. For example, in Huggingface’s translation example, the padding positions in the labels are set to -100 so that no loss is computed over those padding tokens. For more details, check the example here.
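Putting these pieces together, a sketch of the conversion could look like the following (the batch size is an illustrative assumption):

```python
from transformers import DataCollatorForSeq2Seq, TFMT5ForConditionalGeneration

model = TFMT5ForConditionalGeneration.from_pretrained(MODEL_NAME)

# The collator pads each batch dynamically and builds the decoder inputs.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="tf")

# Convert the tokenized Huggingface datasets into tf.data.Dataset objects.
tf_train_dataset = model.prepare_tf_dataset(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=16,  # illustrative batch size
    collate_fn=data_collator,
)
tf_eval_dataset = model.prepare_tf_dataset(
    tokenized_datasets["test"],
    shuffle=False,
    batch_size=16,
    collate_fn=data_collator,
)
```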
Training
We can now use the tf_train_dataset, which is compatible with the TFMT5ForConditionalGeneration model and the Keras API, as simply as:
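A hedged sketch of the compile-and-fit step; the optimizer, learning rate, and number of epochs are illustrative choices rather than the notebook's exact settings.

```python
import tensorflow as tf

# Transformers TF models compute the loss internally when labels are
# present in the batch, so compile() does not need an explicit loss.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5))

model.fit(
    tf_train_dataset,
    validation_data=tf_eval_dataset,
    epochs=3,  # illustrative number of epochs
)
```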
Full demo and inference examples can be found here.
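As a quick illustration of inference (separate from the linked demo), a fine-tuned model can produce corrections with generate; the input sentence below is made up.

```python
# Illustrative inference sketch; the sample sentence is invented.
sample = "Ths sentense has a fiew typos."
inputs = tokenizer(sample, return_tensors="tf")
generated = model.generate(**inputs, max_length=MAX_LENGTH)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```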
You can also subclass TFMT5ForConditionalGeneration in your own custom model and override the train_step method to optimize an objective specified with the GradientTape API. For more on that topic, refer to this notebook.
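A minimal sketch of that pattern, keeping the model's built-in cross-entropy loss just to show the mechanics; it assumes batches arrive as (features, labels) pairs, so adjust the unpacking if your dataset keeps the labels inside the feature dict instead.

```python
import tensorflow as tf
from transformers import TFMT5ForConditionalGeneration

class CustomMT5(TFMT5ForConditionalGeneration):
    def train_step(self, data):
        # Assumes each batch is a (features, labels) pair; adapt the
        # unpacking if your pipeline structures batches differently.
        features, labels = data
        with tf.GradientTape() as tape:
            outputs = self(features, labels=labels, training=True)
            # Replace this with any custom objective over outputs.logits.
            loss = tf.reduce_mean(outputs.loss)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        return {"loss": loss}
```

After compiling with an optimizer, a CustomMT5 loaded via from_pretrained can be fitted on tf_train_dataset exactly like the standard class.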
Further reading
You can check out this article by David Dale on how you can cut off unused weights from the mT5 to focus on a specific language.