Faster than training from scratch — Fine-tuning the English GPT-2 in any language with Hugging Face and fastai v2 (practical case with Portuguese)

Pierre Guillou
40 min read · Jul 14, 2020


The 3 main steps of fine-tuning the English GPT-2 to Portuguese with Hugging Face and fastai v2 (image edited from fast.ai NLP)

In this tutorial, instead of training from scratch, we will see how to fine-tune an English pre-trained transformer-based language model to any other language in just over a day, on one GPU and with a little more than 1 GB of training data. As a practical case, we fine-tune the English pre-trained GPT-2 to Portuguese by wrapping the Transformers and Tokenizers libraries of Hugging Face into fastai v2. We thus create a new language model: GPorTuguese-2, a language model for Portuguese text generation (and more NLP tasks…).

Other posts in the GPT-2 series: (NLP & fastai) GPT-2 | Byte-level BPE, an universal tokenizer but…

Texts generated by GPorTuguese-2 on Covid-19, Netflix, Artificial Intelligence and… unicorns

Examples of texts generated by GPorTuguese-2 (Portuguese GPT-2 small) on Covid-19, Netflix, Artificial Intelligence and… unicorns

Acknowledgment

This tutorial was made possible by the computing power of the AI Lab (University of Brasilia), to which I am attached as an Associate Researcher in NLP, and by the participation of its directors, Professors Fabricio Ataides Braz and Nilton Correia da Silva, in the definition of the NLP strategy. Thank you so much!

AI Lab (University of Brasilia, Brazil)

And special thanks to Sylvain Gugger for his tutorial on Transformers and fastai v2 which is the basis of this tutorial.

I would also like to mention the Nama.ai R&D team and its CEO Rodrigo Scotti, who are taking part in AI research in Brazil to improve online services through the use of generative NLP models.

Table of contents

  • Texts generated by GPorTuguese-2 on Covid-19, Netflix, Artificial Intelligence and… unicorns
  • Acknowledgment
  • Notebooks, Web App and model download
  • Results
  • About the need for language models not just in English… and how to do it in real life
  • Why using fastai v2 over Hugging Face libraries to fine-tune a pre-trained transformer-based language model?
  • About the choice of GPT-2
  • Main coding steps to fine-tune a Hugging Face language model with fastai v2
  • Model sharing and uploading in the Hugging Face model hub
  • Text Generation by our Portuguese GPT-2
  • Conclusion
  • Annex | Other articles about fine-tuning GPT-2 to another language

Notebooks, Web App and model download

The main code of the tutorial is published in this post, organized by paragraph.

To obtain the complete code, simply download the notebook finetuning-English-GPT2-any-language-Portuguese-HuggingFace-fastaiv2.ipynb (nbviewer version). However, as this notebook is very detailed, use this fast notebook finetuning-English-GPT2-any-language-Portuguese-HuggingFace-fastaiv2_FAST.ipynb (nbviewer version) if you just want to run the code without explanation.

In addition, our GPorTuguese-2 (Portuguese GPT-2 small), a language model for Portuguese text generation (and more NLP tasks…), is testable online in the Hugging Face model hub with all usage information at this address:

Results

Analysis of results

In a little more than a day, we got a loss of 3.17, an accuracy of 37.99% and a perplexity of 23.76 (see the validation results table below and the explanations about perplexity at the end of this paragraph). We only used one NVIDIA V100 32GB GPU; with a Distributed Data Parallel (DDP) training mode, we could have divided this time by three, to about 10 hours, with just 2 GPUs. Happy!

+-------------+------+--------------+------------+----------------+-----------------+
| after epoch | loss | accuracy (%) | perplexity | time per epoch | cumulative time |
+-------------+------+--------------+------------+----------------+-----------------+
|           0 | 9.95 |         9.90 |   20950.94 |       00:00:00 |        00:00:00 |
|           1 | 3.64 |        32.52 |      38.12 |        5:48:31 |         5:48:31 |
|           2 | 3.30 |        36.29 |      27.16 |        5:38:18 |        11:26:49 |
|           3 | 3.21 |        37.46 |      24.71 |        6:20:51 |        17:47:40 |
|           4 | 3.19 |        37.74 |      24.21 |        6:06:29 |        23:54:09 |
|           5 | 3.17 |        37.99 |      23.76 |        6:16:22 |        30:10:31 |
+-------------+------+--------------+------------+----------------+-----------------+
Fine-tuning of GPT-2 into Portuguese
Table of training and validation results

After a huge gain at the end of the first epoch (see the validation results graph below), the validation accuracy continues to improve until the end of training, but more slowly. It reaches nearly 40%, which is considered a good performance for a language model (check the notebooks nn-vietnamese.ipynb and nn-turkish.ipynb from Jeremy Howard of fastai).

Validation loss and accuracy of pre-trained English GPT-2 of Hugging Face fine-tuned to Portuguese by fastai v2

The perplexity evolution graph of the validation dataset confirms that the fine-tuning of the vocab and position embedding matrix in the first epoch brought a very significant gain.

Validation perplexity of pre-trained English GPT-2 of Hugging Face fine-tuned to Portuguese by fastai v2

Our results validate the importance of first training the embedding matrices (vocab and position) before fine-tuning the 3 layers groups (each with 4 decoder blocks).

About our fine-tuning strategy

Our Transfer Learning and fine-tuning approach to get a Portuguese GPT-2 from an English one is validated by the results obtained.

Indeed, our model reaches a high performance quickly thanks to our fine-tuning strategy on a pre-trained model, i.e. the reuse of its vocab and position embedding matrices (all token vectors in common between the English and Portuguese vocabs were kept) and of the model weights learned on an English corpus (WebText, 40 GB).

This strategy most likely worked because the language rules of English and Portuguese are not that different (these rules are encoded in the embedding matrices and weights of the pre-trained model)!

About the perplexity of our model

To get an idea of the performance of our GPT-2 fine-tuned to Portuguese, we would need to train the same GPT-2 model on the same Portuguese dataset but from scratch (with randomized position and vocab embedding and model parameters (weights)).

However, we can already compare our perplexity of 23.76 to, for example, the 25.6 from the Transformers Tutorial, about which Sylvain Gugger writes “25.6 as perplexity is kind of amazing” (zero-shot perplexity of the English GPT-2 with BBPE tokenizer on the WikiText2 corpus), or to the 29.41 from the original GPT-2 paper (zero-shot perplexity of the English GPT-2 with BPE tokenizer (not a BBPE one) on the WikiText2 corpus).

Looks good!
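As a reminder of how the two columns of the results table relate, perplexity is usually defined as the exponential of the mean cross-entropy loss (in nats). Here is a minimal sketch, assuming the losses are the rounded values from the table, so the results only match the reported perplexities to about one decimal place:

```python
import math

def perplexity(cross_entropy_loss: float) -> float:
    """Perplexity is exp(mean cross-entropy loss, measured in nats)."""
    return math.exp(cross_entropy_loss)

# Validation losses copied from the results table (rounded to 2 decimals)
losses_by_epoch = {1: 3.64, 3: 3.21, 5: 3.17}

for epoch, loss in losses_by_epoch.items():
    print(f"epoch {epoch}: loss = {loss:.2f} -> perplexity ~ {perplexity(loss):.1f}")
```

A small drop in loss therefore translates into a much larger drop in perplexity, which is why the perplexity curve looks so steep after the first epoch.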

Perplexity table from the original GPT-2 paper (Language Models are Unsupervised Multitask Learners)

About the need for language models not just in English… and how to do it in real life

Even if English is today the most spoken language in the world (around 1.2 billion people), the world is multilingual (for example, there are 34 languages having 45 million or more total speakers in the 2019 edition of Ethnologue, a language reference published by SIL International).

It is therefore necessary to have natural language models trained in all existing languages, and not just in English, since these models constitute the essential basis for the training of models capable of performing a particular task in linguistics (classification, Q&A, synthesis, entity searches, etc.).

This is a color coded diagram to indicate the percentage of English speakers of nearly all the world’s countries. A few small islands have not been accounted for. (image source: List of countries by English-speaking population in Wikipedia)

However, while it is extremely simple and free to download a language model trained in English, for example via the Transformers library of Hugging Face, it is often much more difficult to find online a model trained in another language.

Option 1 | Fast pipeline to localize any transformer-based model to any language

The easiest way to get these language-specific language models would be to use a pipeline of existing pre-trained transformer-based models like the following one:

Fast pipeline to localize any transformer-based model (here, a language model) to any language, for example in Portuguese (image edited from fast.ai NLP)

For example, to obtain a Portuguese GPT-2, we could download from the Transformers library of Hugging Face the OpenAI GPT-2 pre-trained in English and the MarianMT translator (we could also use BART or T5 for the translation) in order to create the following pipeline:

       (input) Portuguese to English (MarianMT) 
>> English pre-trained language model (GPT-2)
>> (output) English to Portuguese (MarianMT)

So, for free and with only a few lines of code, we can get any language model in any language, and even any task-oriented NLP model (classification, Q&A, synthesis, entity searches, etc.) using the same pipeline. Not bad!
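The pipeline above is a plain function composition. The sketch below shows only that structure: `localize` is a hypothetical helper, and the three callables are toy stand-ins, not real models. In practice each one would wrap a real model, for example a MarianMT translator or the English GPT-2 from the Transformers library.

```python
from typing import Callable

TextFn = Callable[[str], str]

def localize(pt_to_en: TextFn, model_en: TextFn, en_to_pt: TextFn) -> TextFn:
    """Wrap an English-only text model so it accepts and returns Portuguese."""
    def model_pt(text_pt: str) -> str:
        text_en = pt_to_en(text_pt)      # (input) Portuguese -> English
        output_en = model_en(text_en)    # English pre-trained model
        return en_to_pt(output_en)       # (output) English -> Portuguese
    return model_pt

# Toy stand-ins (NOT real models), only to show the data flow
toy_pt_to_en = lambda s: s.replace("Olá", "Hello")
toy_model_en = lambda s: s + " world"
toy_en_to_pt = lambda s: s.replace("Hello", "Olá").replace("world", "mundo")

generate_pt = localize(toy_pt_to_en, toy_model_en, toy_en_to_pt)
print(generate_pt("Olá"))  # Olá mundo
```

Because the output passes through two translators, any error made by either of them ends up in the final Portuguese text.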

You will find the code of this pipeline and examples of use for text generation in the post “Fast pipeline to localize any transformer-based model to any language”.

However, the problem with this simple solution is that we depend on the quality of training of 2 pre-trained NLP models, which greatly increases the risk of losing the linguistic singularities and nuances of the desired language.

Option 2 | Fine-tuning of an existing pre-trained model

Therefore, it often becomes necessary to train one's own language model.

Nevertheless, training from scratch a powerful transformer-based language model like GPT-2 or GPT-3 of OpenAI, BART of Facebook or T5 of Google requires tens or even hundreds of GB of text, which is impossible or difficult to find, and gigantic computing power that only a few companies in the world have.

NLP models through time, with their number of parameters (Image credit: TensorFlow blog)

Thus, as it is easy to download a few GB of text from an online language corpus (Wikipedia, OSCAR, Common Crawl for example) and rent an NVIDIA V100 GPU for $1.24 an hour (GCP, AWS, Azure for example), it is more realistic for the majority of people and organizations wishing to use a language model in a language other than English to fine-tune, on a few GB of text, a model already pre-trained in English (i.e. fine-tuning a model obtained by Transfer Learning) using Deep Learning frameworks such as TensorFlow+Keras or PyTorch+fastai.

This tutorial shows how to implement this second option, and you will find examples of use for text generation in the paragraph “Text Generation by our Portuguese GPT-2” at the end of this tutorial.

Why using fastai v2 over Hugging Face libraries to fine-tune a pre-trained transformer-based language model?

Tokenizers and Transformers from Hugging Face

The Tokenizers and Transformers libraries from Hugging Face (HF) are today among the most up-to-date NLP (Natural Language Processing) libraries and are used all over the world (the library versions we used are from July 2020: transformers 3.0.0 and tokenizers 0.8.0).

Hugging Face

According to the HF official documentation, they were designed with two strong goals in mind:

  • be as easy and fast to use as possible
  • provide state-of-the-art models with performances as close as possible to the original models

However, as written in the Philosophy paragraph of the Quickstart HF page:

the Transformers library is NOT a modular toolbox of building blocks for neural nets. If you want to extend/build-upon the library, just use regular Python/PyTorch modules and inherit from the base classes of the library to reuse functionalities like model loading/saving.

Indeed, reading the new Hugging Face tutorials from June 2020 confirms that plain PyTorch must be used in order to train from scratch or fine-tune a pre-trained model in the Transformers library.

For example, the new Training and fine-tuning tutorial explains how to fine-tune in native PyTorch. It is very helpful, but how do we apply the 1cycle policy fine-tuning method, for example? Or how do we easily freeze or unfreeze some layers groups, as in fastai v2 with the functions learn.unfreeze() and learn.freeze_to(), instead of typing full PyTorch code?

fastai v2

Therefore, despite the ready-to-run py files published by Hugging Face (for example, run_language_modeling.py for fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa)), when it becomes necessary to fine-tune a pre-trained model to another language and/or to another task, we need easy fine-tuning methods on top of regular Python/PyTorch modules in order to apply modern Transfer Learning and fine-tuning techniques.

Since fastai v2 provides all of these powerful fine-tuning techniques, this is a primary candidate library for training transformer-based language models pre-trained with the Tokenizers and Transformers libraries of Hugging Face.

fastai v2

Here is a non-exhaustive list of the fastai v2 fine-tuning techniques based on Transfer Learning:

  • Learning rate finder (method that helps finding the best learning rate to train the model)
  • Mixed precision training (some of the operations will be done in FP16, others in FP32 in order to speed up the training)
  • Gradual unfreezing (the model is split into layers groups, which lets us decide which layers are trained at each stage)
  • 1cycle policy (introduced by Leslie N. Smith et al. in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates; it schedules the learning rate with a cosine annealing)
  • Differential learning rates (a specific learning rate is set for each layers group)
  • Distributed training (training distributed on different GPUs in order to speed up the training)
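To make the 1cycle policy more concrete, the learning rate can be sketched as two cosine segments: a warm-up from a low value up to `lr_max`, then an annealing down to a much lower value. The function names are hypothetical, and the defaults (`pct_start=0.25`, `div=25`, `div_final=1e5`) are assumptions that mirror what we understand to be the defaults of `fit_one_cycle` in fastai v2. This is an illustration, not the library's code.

```python
import math

def cos_anneal(start: float, end: float, pos: float) -> float:
    """Cosine interpolation from `start` (pos=0) to `end` (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos: float, lr_max: float, pct_start: float = 0.25,
                 div: float = 25.0, div_final: float = 1e5) -> float:
    """Learning rate at training progress `pos` in [0, 1]."""
    if pos < pct_start:
        # Warm-up: lr_max/div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing: lr_max -> lr_max/div_final
    return cos_anneal(lr_max, lr_max / div_final,
                      (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.125, 0.25, 0.5, 1.0):
    print(f"progress {pos:.3f}: lr = {one_cycle_lr(pos, lr_max=1e-3):.2e}")
```

Under these assumptions, the rate peaks a quarter of the way through training and ends far below its starting value.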

About the choice of GPT-2

In order to demonstrate the feasibility of fine-tuning Hugging Face models via fastai v2, we wanted to choose an emblematic model of the Transformer revolution in NLP since 2017.

The original transformer model is made up of an encoder and a decoder (image credit: The illustrated GPT-2)

Thus, between the 2 historic transformer-based models, GPT-2 and BERT, we chose GPT-2 because it strongly influenced minds beyond the circle of Deep Learning specialists in early 2019 by writing texts of a quality level close to that of humans. Although it is today “exceeded” in number of parameters and performance by more recent models like BART, T5 and of course GPT-3 (175 billion parameters!), it remains a reference and a model used in research and applications.

(1/2) OpenAI GPT-2 is a transformer-based language model using only decoder blocks (image credit: The illustrated GPT-2)
(2/2) OpenAI GPT-2 is a transformer-based language model using only decoder blocks (note: we use an input sequence of 1024, not 4000; image credit: The illustrated GPT-2)

Note: for those who want to understand better how GPT-2 works, read the following posts:

About the version of GPT-2: there are several versions of the GPT-2 model, which differ in size (look at the transformers documentation for more details). Here, we use the small version, the one with the smallest number of weights (124 million, not 117 million as written in the original paper), but you can change the model used by changing the content of pretrained_weights (if it's not a GPT-2 model, you'll need to change the classes used for the model and the tokenizer of course).

We used the English pre-trained GPT-2 small and its Byte-level BPE tokenizer in this tutorial (image credit: The illustrated GPT-2)

English pre-trained GPT-2 small

  • 12-layer, 768-hidden, 12-heads
  • 124M parameters, file of 548 MB
  • Download time: about 10 minutes
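The 124M figure can be checked by hand from the architecture numbers above. The sketch below assumes the layer shapes of the standard GPT-2 small, which are not all listed in this post: a vocabulary of 50,257 tokens, 1,024 positions, a 4x feed-forward expansion, a bias on every linear layer, and an output head that shares its weights with the vocab embedding matrix (so it adds no parameters). The function name is hypothetical.

```python
def gpt2_param_count(n_layer: int = 12, d_model: int = 768,
                     vocab_size: int = 50257, n_positions: int = 1024) -> int:
    """Count trainable parameters, assuming a weight-tied output head."""
    wte = vocab_size * d_model                 # vocab embedding matrix
    wpe = n_positions * d_model                # position embedding matrix
    layer_norm = 2 * d_model                   # scale + bias
    attn = (d_model * 3 * d_model + 3 * d_model) \
         + (d_model * d_model + d_model)       # qkv projection + output projection
    mlp = (d_model * 4 * d_model + 4 * d_model) \
        + (4 * d_model * d_model + d_model)    # expansion + contraction
    block = 2 * layer_norm + attn + mlp        # one decoder block
    return wte + wpe + n_layer * block + layer_norm  # + final LayerNorm

total = gpt2_param_count()
print(f"{total:,} parameters (~{total / 1e6:.0f}M)")  # 124,439,808 parameters (~124M)
```

Under these assumptions, the vocab embedding matrix alone accounts for about 38.6 million of these parameters, which is why adapting it to the Portuguese vocab matters so much.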

English pre-trained Byte-level BPE tokenizer

Note: to understand better what is a Byte-level BPE tokenizer, read this post: Byte-level BPE, an universal tokenizer but…

Main coding steps to fine-tune a Hugging Face language model with fastai v2

You will find in the tutorial notebook the code, detailed explanations and results for the 6 main coding steps to fine-tune a Hugging Face language model with fastai v2. We copied the key parts into this post in order to focus on them.

1. Initialization
2. Download Wikipedia in Portuguese
3. Download a GPT-2 English pre-trained model and train a GPT-2 tokenizer with a vocab in Portuguese
3.1 Get the pre-trained GPT-2 Tokenizer & Model (pre-trained with an English corpus) from the Transformers library (Hugging Face)
3.2 Train a Byte-level BPE (BBPE) Tokenizer on the Portuguese Wikipedia corpus by using the Tokenizers library (Hugging Face)
3.3 Import the tokenizer Portuguese config files into the pre-trained GPT-2 Tokenizer
4. Create a fastai tokenizer and update the embedding matrix of the GPT-2 English pre-trained model
4.1 GPT2TokenizerFast (imported GPT2 tokenizer) --> fastai Tokenizer
4.2 Change vocab embedding in the GPT-2 pre-trained model to adapt to the Portuguese vocab
5. Create fastai v2 Datasets and Dataloaders
6. Fine-tuning the model
6.1 Splitter (get layers groups)
6.2 Learner
6.2.1 Freeze all layers but the last layers group (wte, wpe embedding matrices and last LayerNorm)
6.2.2 Freeze all layers but the last 2 layers groups
6.2.3 Freeze all layers but the last 3 layers groups
6.2.4 Unfreeze all layers

However, these 6 main steps can be summarized in 3 main ones:

Fine-tuning the English GPT-2 to Portuguese with Hugging Face and fastai v2 in 3 main steps (image edited from fast.ai NLP)
  1. Initialization & download (download of Portuguese Wikipedia and GPT-2 English pre-trained model and tokenizer)
  2. GPT-2 tokenizer with a Portuguese vocab (train a GPT-2 tokenizer with a vocab in Portuguese, wrap it into a fastai v2 tokenizer and update the embedding matrix of the GPT-2 English pre-trained model according to the new Portuguese vocab: keep the embedding vectors of the common tokens between English and Portuguese vocabs)
  3. Fine-tune on Portuguese Wikipedia the GPT-2 model with fastai v2 training functionalities

Let’s start our journey to GPT-2 fine-tuned into Portuguese!

1. Initialization

# libraries installation
# fastai v2: read https://dev.fast.ai/#Installing
# tokenizers: !pip install tokenizers
# transformers: !pip install transformers
# import fastai v2
from fastai2.text.all import *
from nlputils_fastai2 import *
# import tokenizers and transformers
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from tokenizers import ByteLevelBPETokenizer
# setup new path_data and create the lang folder
lang = 'pt'
name = f'{lang}wiki'
config = Config()
data_path = config['data_path']
path_data = data_path/name
path_data.mkdir(exist_ok=True, parents=True)

2. Download Wikipedia in Portuguese

In Wikimedia Downloads, you will find the dump of the Portuguese Wikipedia, which had 1,037,991 articles at the date of the study (07/03/2020).

By selecting those with a minimum text length of 1,800, we downloaded 20% of these articles (204,315 files), which represent about 200 million words for a total size of 1.6 GB.

This dataset size has to be compared to the 40 GB of WebText (text extracted from the Internet, excluding Wikipedia) used by OpenAI to train the English GPT-2 from scratch (see “About the English dataset used to train GPT-2” at the end of this paragraph).

We use 25 times less training data to obtain a GPT-2 in Portuguese than that used to obtain the GPT2 in English.

Note: all the following methods come from the file nlputils_fastai2.py from fastai. We also tried to use the nlp library of Hugging Face to download the Portuguese Wikipedia, but we faced an unsolved issue (see the notebook).

# download Portuguese Wikipedia
get_wiki(path_data,lang)
# create one text file by article
dest = split_wiki(path_data,lang)
# get all articles in one text file and one csv file
get_one_clean_file(dest,lang)
get_one_clean_csv_file(dest,lang)

Note: the text file (all the articles in one file) will allow the training of the Portuguese tokenizer and the csv one will facilitate the tests of the study.

First articles from downloaded Portuguese Wikipedia

About the English dataset used to train GPT-2

(source) The resulting dataset, WebText, contains the text subset of these 45 million links. To extract the text from HTML responses we use a combination of the Dragnet (Peters & Lecocq, 2013) and Newspaper content extractors. All results presented in this paper use a preliminary version of WebText which does not include links created after Dec 2017 and which after de-duplication and some heuristic based cleaning contains slightly over 8 million documents for a total of 40 GB of text. We removed all Wikipedia documents from WebText since it is a common data source for other datasets and could complicate analysis due to overlapping training data with test evaluation tasks.

3. Download a GPT-2 English pre-trained model and train a GPT-2 tokenizer with a vocab in Portuguese

We are following 3 steps in order to get a GPT-2 tokenizer with a vocab in Portuguese:

  1. Get the pre-trained GPT-2 Tokenizer & Model (pre-trained with an English corpus) from the Transformers library (Hugging Face): it will give us the tokenizer structure we need and the pre-trained model weights (it’s better to start training our GPT-2 model in Portuguese from weights already trained even in another language than from random values)
  2. Train a Byte-level BPE (BBPE) Tokenizer on the Portuguese Wikipedia corpus by using the Tokenizers library (Hugging Face): this will give us the vocabulary files in Portuguese of our GPT-2 tokenizer.
  3. Import the tokenizer Portuguese config files (vocab.json, merges.txt) into the pre-trained GPT-2 Tokenizer: it will give us a GPT-2 tokenizer structure with the vocab in Portuguese.

One relevant point is that we trained our Portuguese Byte-level BPE tokenizer on Portuguese Wikipedia (here, 1.6 GB) in only 2min 7s. Thanks Hugging Face!

# 1. Get the pre-trained GPT-2 tokenizer and model (pre-trained with an English
# corpus) from the Transformers library (Hugging Face)
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from tokenizers import ByteLevelBPETokenizer
pretrained_weights = 'gpt2'
tokenizer_en = GPT2TokenizerFast.from_pretrained(pretrained_weights)
# GPT-2 has no padding token: reuse the end-of-text token for padding
tokenizer_en.pad_token = tokenizer_en.eos_token
# English pre-trained model whose weights will be fine-tuned
model = GPT2LMHeadModel.from_pretrained(pretrained_weights)

# 2. Train a Byte-level BPE (BBPE) tokenizer on the Portuguese
# Wikipedia corpus by using the Tokenizers library (Hugging Face)
# 2.1 Use the same vocab size as the English GPT-2 tokenizer
ByteLevelBPE_tokenizer_pt_vocab_size = tokenizer_en.vocab_size
# 2.2 ByteLevelBPETokenizer represents a Byte-level BPE
# as introduced by OpenAI with their GPT-2 model
ByteLevelBPE_tokenizer_pt = ByteLevelBPETokenizer()
# 2.3 Get list of paths to corpus files
# and customize training with <|endoftext|> special GPT-2 token
paths = [str(path_data/'all_texts_ptwiki.txt')]
ByteLevelBPE_tokenizer_pt.train(files=paths,
                                vocab_size=ByteLevelBPE_tokenizer_pt_vocab_size,
                                min_frequency=2,
                                special_tokens=["<|endoftext|>"])
# Get sequence length max of 1024
ByteLevelBPE_tokenizer_pt.enable_truncation(max_length=1024)
# 2.4 Save the tokenizer files (vocab.json, merges.txt)
ByteLevelBPE_tokenizer_pt_rep = 'ByteLevelBPE_tokenizer_pt'
path_to_ByteLevelBPE_tokenizer_pt_rep = path_data/ByteLevelBPE_tokenizer_pt_rep
if not path_to_ByteLevelBPE_tokenizer_pt_rep.exists():
    path_to_ByteLevelBPE_tokenizer_pt_rep.mkdir(exist_ok=True, parents=True)
ByteLevelBPE_tokenizer_pt.save_model(str(path_to_ByteLevelBPE_tokenizer_pt_rep))

# 3. Import the tokenizer config files in Portuguese into the pre-trained GPT2 Tokenizer
tokenizer_pt = GPT2TokenizerFast.from_pretrained(
    str(path_to_ByteLevelBPE_tokenizer_pt_rep),
    pad_token='<|endoftext|>')
# Get sequence length max of 1024
tokenizer_pt.model_max_length = 1024
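A byte-level BPE tokenizer never meets an unknown character because it first encodes the text as UTF-8 bytes and maps each of the 256 possible byte values to a printable character, before learning any merge. The sketch below is modelled on the byte-to-character table of the reference GPT-2 tokenizer; the helper names are hypothetical and no BPE merge is applied. It only shows the byte-level view of a string, which explains why vocabulary entries for words preceded by a space start with `Ġ`.

```python
def byte_to_char_table() -> dict:
    """Map each of the 256 byte values to a printable character.

    Printable Latin-1 bytes map to themselves; the others (control
    characters, space, ...) are shifted to code points 256 and above.
    """
    printable = (list(range(ord("!"), ord("~") + 1))
                 + list(range(ord("¡"), ord("¬") + 1))
                 + list(range(ord("®"), ord("ÿ") + 1)))
    table, shift = {}, 0
    for b in range(256):
        if b in printable:
            table[b] = chr(b)
        else:
            table[b] = chr(256 + shift)
            shift += 1
    return table

TABLE = byte_to_char_table()

def to_byte_level(text: str) -> str:
    """Byte-level view of `text`, before any BPE merge is applied."""
    return "".join(TABLE[b] for b in text.encode("utf-8"))

print(to_byte_level(" este"))  # Ġeste  (the space byte becomes 'Ġ')
print(to_byte_level("ç"))      # Ã§     (one accented letter = 2 bytes)
```

Accented Portuguese letters take two bytes each, so a vocab trained on English text tends to split Portuguese words into more tokens, which is one motivation for training a Portuguese vocab.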

4. Create a fastai tokenizer and update the embedding matrix of the GPT-2 English pre-trained model

Now let’s see how we can use fastai v2 to fine-tune this model on Wikipedia in Portuguese, using all the fastai v2 training and fine-tuning utilities.

We will follow these 2 steps:

  1. GPT2TokenizerFast (imported GPT-2 tokenizer) → fastai Tokenizer: to process the data to train a model, we need to build a fastai tokenizer from the GPT-2 tokenizer with vocab in Portuguese.
  2. Change vocab embedding (wte matrix) in the GPT-2 pre-trained model to adapt to the Portuguese vocab: as the vocab embedding matrix (wte) of the pre-trained GPT-2 model corresponds to the English vocabulary, we’ll keep the embedding vectors of the common tokens between the English and Portuguese vocab.

(text from Sylvain Gugger Transformers Tutorial) To process this data to train a model, we need to build a Transform that will be applied lazily. In a fastai Transform you can define:

  • an encodes method that is applied when you call the transform (a bit like the forward method in a nn.Module)
  • a decodes method that is applied when you call the decode method of the transform, if you need to decode anything for showing purposes (like converting ids to a text here)
  • a setups method that sets some inner state of the Transform (not needed here)
# 1. GPT2TokenizerFast (imported GPT-2 tokenizer) → fastai Tokenizer
class TransformersTokenizer(Transform):
    def __init__(self, tokenizer): self.tokenizer = tokenizer
    def encodes(self, x):
        toks = self.tokenizer.tokenize(x)
        return tensor(self.tokenizer.convert_tokens_to_ids(toks))
    def decodes(self, x): return TitledStr(self.tokenizer.decode(x.cpu().numpy()))

tokenizer_fastai_en = TransformersTokenizer(tokenizer_en)
tokenizer_fastai_pt = TransformersTokenizer(tokenizer_pt)

# 2. Change vocab embedding in the GPT-2 pre-trained model to adapt to the Portuguese vocab
# (`model` is the English pre-trained GPT2LMHeadModel)
# Get weights of the old wte
old_wgts = model.transformer.get_input_embeddings().weight.clone().detach()
# Get the mean embedding vector of the old wte
wgts_m = old_wgts.mean(0)
# Initialize vocab size and weights of the new wte
new_vocab_size = tokenizer_fastai_pt.tokenizer.vocab_size
new_wgts = old_wgts.new_zeros(new_vocab_size,old_wgts.size(1))
# Get the new wte keeping the embedding vectors of tokens
# in common in the 2 vocabs
# A token present in the new vocab but not in the old one
# gets the mean embedding vector of the old wte
old_vocab = tokenizer_fastai_en.tokenizer.get_vocab()
new_vocab = tokenizer_fastai_pt.tokenizer.get_vocab()
same_tokens_list = list()
different_tokens_list = list()

for w,idx_new in new_vocab.items():
    idx_old = old_vocab.get(w, -1)
    if idx_old>=0:
        new_wgts[idx_new] = old_wgts[idx_old]
        same_tokens_list.append((w,idx_new))
    else:
        new_wgts[idx_new] = wgts_m
        different_tokens_list.append((w,idx_new))

# setup in model the new wte
new_wte = nn.Embedding(new_vocab_size,old_wgts.size(1))
new_wte.weight.data = new_wgts
model.transformer.set_input_embeddings(new_wte)
# save new_wgts
torch.save(new_wgts, path_data/'new_wte_wgts.pt')
# save same_tokens_list and different_tokens_list
torch.save(same_tokens_list, path_data/'same_tokens_list.pt')
torch.save(different_tokens_list, path_data/'different_tokens_list.pt')
# Changing lm_head weights with the new embedding matrix
model.lm_head.weight = model.transformer.wte.weight

Portuguese embedding wte matrix setup done!

We kept 12,948 embedding vectors from the English one (~25%).
We did not keep 37,309 embedding vectors (~75%) from the English one (instead, we used the old wte mean vector).

  • 15 first tokens IN common between the 2 vocabs:
    [('ĠQuit', 40195), ('Smith', 32470), ('Ġomit', 39040), ('oc', 574), ('ym', 18252), ('Ġactual', 9443), ('ck', 911), ('ĠPremier', 16558), ('Ġeste', 987), ('ĠInd', 3438), ('Ġbol', 4203), ('phen', 35836), ('ĠParticip', 36689), ('ĠZeus', 19316), ('Ġnan', 39770)]
  • 15 first Portuguese tokens NOT in common between the 2 vocabs:
    [('PSDB', 23151), ('Ġenvio', 19270), ('Ġocupação', 5938), ('Ġdocumentada', 30011), ('Ġduros', 36706), ('visto', 44422), ('ĠSiro', 43061), ('Ġdestacavam', 47397), ('Ġarqui', 49060), ('ĠArte', 5977), ('ĠValor', 29721), ('Ġalinhados', 38446), ('Ġnúmeros', 4626), ('Ġpênis', 31686), ('cisa', 29710)]
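The remapping rule for the vocab embedding matrix can be followed on a toy example with plain Python lists. The two tiny vocabularies and the 2-dimensional vectors below are invented for illustration, and `remap_embeddings` is a hypothetical helper. Only the rule itself comes from the code above: copy the vector of a token shared by both vocabs, otherwise use the mean vector of the old matrix.

```python
def remap_embeddings(old_vocab, old_wgts, new_vocab):
    """Build the new embedding rows from the old ones, token by token."""
    dim = len(old_wgts[0])
    # Mean embedding vector of the old matrix (fallback for new tokens)
    mean_vec = [sum(row[d] for row in old_wgts) / len(old_wgts) for d in range(dim)]
    new_wgts = [None] * len(new_vocab)
    kept = []
    for token, idx_new in new_vocab.items():
        idx_old = old_vocab.get(token)
        if idx_old is not None:
            new_wgts[idx_new] = list(old_wgts[idx_old])  # shared token: keep its vector
            kept.append(token)
        else:
            new_wgts[idx_new] = list(mean_vec)           # new token: start from the mean
    return new_wgts, kept

# Invented toy data: 3 "English" tokens, 3 "Portuguese" tokens, 2 shared
old_vocab = {"Ġactual": 0, "Ġthe": 1, "ck": 2}
old_wgts = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
new_vocab = {"Ġnúmeros": 0, "ck": 1, "Ġactual": 2}

new_wgts, kept = remap_embeddings(old_vocab, old_wgts, new_vocab)
print(new_wgts)  # [[1.0, 1.0], [2.0, 2.0], [1.0, 0.0]]
print(kept)      # ['ck', 'Ġactual']
```

Note that a shared token usually sits at a different index in each vocab, which is why the copy goes through the token string rather than the index.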

5. Create fastai v2 Datasets and Dataloaders

(text from Sylvain Gugger Transformers Tutorial) You can then group your data with this Transform using a TfmdLists. It has an s in its name because it contains the training and validation datasets.

We indicate the indices of the training dataset and the validation dataset with splits (here, 80% of the indices randomly chosen, then all the remaining indices).

We specify dl_type=LMDataLoader in the TfmdLists for when we will convert this TfmdLists to DataLoaders: we will use an LMDataLoader since we have a language modeling problem, not the usual fastai TfmdDL.

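# df: DataFrame with one article per row (column `text`), loaded from the csv file
# train = 80%
# validation = 20%
num = int(0.8*len(df))
# Random permutation: each article index appears exactly once,
# so the training and validation sets cannot overlap
idxs = np.random.permutation(len(df))
idxs_train = idxs[:num]
idxs_val = idxs[num:]
# We gather all texts in one numpy array
# (since it will be easier to use this way with fastai)
all_texts = np.concatenate([df.iloc[idxs_train].text.values, df.iloc[idxs_val].text.values])
# all_texts holds the training texts first, then the validation texts
splits = [list(range(num)), list(range(num, len(all_texts)))]
tls = TfmdLists(all_texts, TransformersTokenizer(tokenizer_pt), splits=splits, dl_type=LMDataLoader)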

(text from Sylvain Gugger Transformers Tutorial) The fastai v2 library expects the data to be assembled in a DataLoaders object (something that has a training and validation dataloader). We can get one by using the dataloaders method. We just have to specify a batch size and a sequence length:

  • Let’s use a batch size of 8 (a value higher gives a “CUDA out of memory error” on our single GPU).
  • Since the GPT-2 model was trained with sequences of size 1024, we use this sequence length (it’s a stateless model, so it will change the perplexity if we use less).
bs,sl = 8,1024
dls = tls.dataloaders(bs=bs, seq_len=sl)

6. Fine-tuning the model

(text from Sylvain Gugger Transformers Tutorial) The Hugging Face model will return a tuple in outputs, with the actual predictions and some additional activations (should we want to use them in some regularization scheme). To work inside the fastai v2 training loop, we will need to drop those using a Callback: we use those to alter the behavior of the training loop.

Here we need to write the event after_pred and replace self.learn.pred (which contains the predictions that will be passed to the loss function) by just its first element. In callbacks, there is a shortcut that lets you access any of the underlying Learner attributes, so we can write self.pred[0] instead of self.learn.pred[0]. That shortcut only works for read access, not write, so we have to write self.learn.pred on the left side of the assignment (otherwise we would set a pred attribute in the Callback).

class DropOutput(Callback):
    def after_pred(self): self.learn.pred = self.pred[0]
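The read-only nature of that shortcut can be reproduced with plain Python. The classes below are simplified, hypothetical stand-ins, not fastai's actual Learner and Callback: attribute reads that fail on the callback are forwarded to the learner through `__getattr__`, while an assignment always lands on the object being assigned to.

```python
class ToyLearner:
    def __init__(self):
        # Stand-in for the tuple returned by the Hugging Face model
        self.pred = ("logits", "extra activations")

class ToyCallback:
    def __init__(self, learn):
        self.learn = learn
    def __getattr__(self, name):
        # Only called when normal lookup fails: forward the read to the learner
        return getattr(self.learn, name)

class DropOutputOK(ToyCallback):
    def after_pred(self):
        self.learn.pred = self.pred[0]   # the write goes to the learner

class DropOutputWrong(ToyCallback):
    def after_pred(self):
        self.pred = self.pred[0]         # creates a new attribute on the callback

learn_ok, learn_wrong = ToyLearner(), ToyLearner()
DropOutputOK(learn_ok).after_pred()
DropOutputWrong(learn_wrong).after_pred()

print(learn_ok.pred)     # logits
print(learn_wrong.pred)  # ('logits', 'extra activations')
```

In the second case the learner still holds the full tuple, so the loss function would receive the wrong object.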

6.1 Splitter (get the layers groups)

The model has 2 main layers groups (or parameter groups): transformer and lm_head. As we can read in The illustrated GPT2, the lm_head is a copy of the vocab embedding matrix wte; it is used to get, after the softmax, the probability of each token in the vocab. Therefore, we only need to split the transformer layers group to get all layers.

transformer

  • (wte) vocab embedding (vocab tokens → embedding)
  • (wpe) positional embedding (tokens positions in input sequence → embedding)
  • 12 decoder blocks (attention heads)
  • (ln_f) final LayerNorm

lm_head

  • Linear layer (its weights are shared with wte)

Now, we can create our layers groups that will allow us to use all the fastai v2 fine-tuning techniques. Moreover, we decided to follow the fine-tuning method shown for text classification training in the notebook 10_nlp.ipynb by creating 4 layers groups: 3 layers groups of 4 decoder blocks each and one embedding group with the wte and wpe matrices (plus the final LayerNorm).

def splitter(model):
    "Split a GPT2 `model` in 4 groups for differential learning rates."

    # First layers group: decoder blocks from 0 to 3
    modules = []
    for i in range(4): modules.append(model.transformer.h[i])
    groups = [nn.Sequential(*modules)]
    # Second layers group: decoder blocks from 4 to 7
    modules = []
    for i in range(4,8,1): modules.append(model.transformer.h[i])
    groups = L(groups + [nn.Sequential(*modules)])
    # Third layers group: decoder blocks from 8 to 11
    modules = []
    for i in range(8,12,1): modules.append(model.transformer.h[i])
    groups = L(groups + [nn.Sequential(*modules)])

    # Fourth layers group: embedding matrices wte and wpe
    # + LayerNorm at the model output
    groups = L(groups + [nn.Sequential(model.transformer.wte,model.transformer.wpe,model.transformer.ln_f)])

    return groups.map(params)
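
The three loops over the decoder blocks follow the same pattern. As a small sketch of that pattern only (the helper block_groups is hypothetical and works on block indices, not on the real model), the grouping can be written as:

```python
def block_groups(n_blocks=12, n_groups=3):
    """Return the decoder block indices of each layers group (hypothetical helper)."""
    size = n_blocks // n_groups
    return [list(range(g * size, (g + 1) * size)) for g in range(n_groups)]

groups = block_groups()
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
```

The fourth group (wte, wpe and the final LayerNorm) is then appended last, so it is the first one to be trained during gradual unfreezing.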

6.2 Learner

(text from Sylvain Gugger Transformers Tutorial) Now, we are ready to create our Learner, which is a fastai object grouping data, model and loss function and handles model training or inference. Since we are in a language model setting, we pass accuracy and perplexity as metrics, and we need to use the callback we just defined. Lastly, we use mixed precision to save every bit of memory we can (and if you have a modern GPU, it will also make training faster).

learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(),
                splitter = splitter,
                cbs=[DropOutput],
                metrics=[accuracy, Perplexity()]).to_fp16()

We can check how good the model is without any fine-tuning step by running learn.validate(). In 53min 2s, we got:

  • validation loss: 9.949938774108887
  • validation accuracy: 0.09898579120635986
  • validation perplexity: 20950.939453125

Not bad: nearly 10% accuracy without any fine-tuning! It means we start our journey to GPT-2 in Portuguese with a language model that already has a strong knowledge of the language rules (weights) and a basic one of Portuguese (25% of its vocab embedding matrix).
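
The loss and perplexity values above are two views of the same quantity. As a quick check, assuming the Perplexity metric is the exponential of the mean cross-entropy loss, the perplexity can be recomputed from the validation loss:

```python
import math

def perplexity(cross_entropy_loss):
    """Perplexity is the exponential of the mean cross-entropy loss (in nats)."""
    return math.exp(cross_entropy_loss)

# Validation loss reported by learn.validate() before any fine-tuning
before = perplexity(9.949938774108887)
print(before)
```

The result matches the reported validation perplexity up to floating-point rounding, which is why a small drop in loss translates into a large drop in perplexity.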

Now that we have a Learner, we will use during training all the fastai v2 fine-tuning techniques seen for text classification training (see the notebook 10_nlp.ipynb about "NLP Deep Dive: RNNs") to take advantage of the Transfer Learning of the GPT-2 pre-trained embedding matrices and model from Hugging Face Transformers:

  • Learning rate finder (method that helps finding the best learning rate to train the model)
  • Mixed precision training (some of the operations will be done in FP16, others in FP32 in order to speed up the training)
  • Gradual unfreezing (the model has 4 layers groups created by our method splitter : the embedding one and the 3 groups of 4 decoder blocks each)
  • 1cycle policy with the method fit_one_cycle() (The 1cycle policy was introduced by Leslie N. Smith et al. in Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates. It schedules the learning rate with a cosine annealing from lr_max/div to lr_max then lr_max/div_final (pass an array to lr_max if you want to use differential learning rates) and the momentum with cosine annealing according to the values in moms. The first phase takes pct_start of the training. You can optionally pass additional cbs and reset_opt.)
  • Differential learning rates (a different learning rate for each layers group: the biggest one for the embedding group, and the smallest one for the first 4 decoder blocks)
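
To visualize the 1cycle description above, here is a minimal pure-Python sketch of the learning rate schedule. The default values div=25, div_final=1e5 and pct_start=0.25 are assumptions meant to mirror the fastai defaults, and the function names are hypothetical; the momentum schedule is left out.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] under a 1cycle-style schedule."""
    if pos < pct_start:
        # Warm-up phase: lr_max/div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    # Annealing phase: lr_max -> lr_max/div_final
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

# Learning rate at every percent of training for lr_max = 2e-3
lrs = [one_cycle_lr(p / 100, lr_max=2e-3) for p in range(101)]
print(lrs[0], lrs[25], lrs[100])
```

Under these assumptions the learning rate starts at lr_max/25, peaks at lr_max after a quarter of the training, and ends close to zero.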

6.2.1 Freeze all layers but the last layers group (do not freeze wte, wpe embedding matrices and last LayerNorm)

learn.freeze()
learn.summary()
GPT2LMHeadModel (Input shape: ['8 x 1024'])
================================================================
Layer (type) Output Shape Param # Trainable
================================================================
Embedding 8 x 1024 x 768 38,597,376 True
________________________________________________________________
Embedding 8 x 1024 x 768 786,432 True
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 2304 1,771,776 False
________________________________________________________________
Conv1D 8 x 1024 x 768 590,592 False
________________________________________________________________
Dropout 8 x 12 x 1024 x 102 0 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 False
________________________________________________________________
Conv1D 8 x 1024 x 3072 2,362,368 False
________________________________________________________________
Conv1D 8 x 1024 x 768 2,360,064 False
________________________________________________________________
Dropout 8 x 1024 x 768 0 False
________________________________________________________________
LayerNorm 8 x 1024 x 768 1,536 True
________________________________________________________________
Linear 8 x 1024 x 50257 38,597,376 True
________________________________________________________________

Total params: 163,037,184
Total trainable params: 77,982,720
Total non-trainable params: 85,054,464

Optimizer used: <function Adam at 0x7fce2f8dae60>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group number 3

Callbacks:
- DropOutput
- ModelToHalf
- TrainEvalCallback
- Recorder
- ProgressCallback
- MixedPrecision

The learn.summary() method gives almost the right numbers. In fact, it counts twice the weights of the wte matrix (vocab embedding matrix) because they are duplicated in the weights of the output linear layer.

The real numbers are:

  • Total params: 163,037,184 - 38,597,376 = 124,439,808 (about 124 million)
  • Total trainable params: 77,982,720 - 38,597,376 = 39,385,344 (about 39 million)
  • Total non-trainable params: 85,054,464 (about 85 million)
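
These corrected numbers can be recomputed from the model dimensions and the per-layer counts printed by learn.summary(). The sketch below is only arithmetic on those values (the variable names are ours) and also gives the trainable counts for the next unfreezing steps.

```python
vocab_size, n_ctx, d_model, n_blocks = 50257, 1024, 768, 12

wte = vocab_size * d_model   # vocab embedding, shared with the output Linear layer
wpe = n_ctx * d_model        # position embedding
layer_norm = 2 * d_model     # scale and bias
# One decoder block: 2 LayerNorms + 4 Conv1D layers (values read from learn.summary())
block = 2 * layer_norm + 1_771_776 + 590_592 + 2_362_368 + 2_360_064

total = wte + wpe + n_blocks * block + layer_norm
embedding_group = wte + wpe + layer_norm  # last layers group of the splitter

# Trainable / non-trainable params when 0, 4, 8 or 12 decoder blocks are unfrozen
for unfrozen_blocks in (0, 4, 8, 12):
    trainable = embedding_group + unfrozen_blocks * block
    print(unfrozen_blocks, trainable, total - trainable)
```

With 0 unfrozen blocks this gives 39,385,344 trainable and 85,054,464 non-trainable parameters, the values listed above.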

Now, let’s choose the best learning rate to launch the fine-tuning of the Portuguese GPT-2 thanks to the fastai v2 learning rate finder.

learn.lr_find()
Results from learn.lr_find() before starting training the Portuguese GPT-2

The learning rate finder curve suggests a learning rate minimum of 6e-3. Let’s use 2e-3, which seems to give the highest decrease in loss according to the previous graph.

learn.freeze()
learn.fit_one_cycle(1, 2e-3)
epoch 0
train_loss 3.803344
valid_loss 3.640777
accuracy 0.325177
perplexity 38.121441
time 5:48:31

In just one epoch, our model went

  • from an accuracy of 9.90% to 32.52%
  • from a perplexity of 20950.94 to 38.12

Not too bad!

We can plot the training and validation loss curves thanks to the fastai v2 loss plotting function in order to visually verify the strong improvement of our model (i.e. the strong reduction in training and validation losses).

learn.recorder.plot_loss()
Evolution of training and validation losses during the first fine-tuning epoch of the Portuguese GPT-2

Now, we can pass -2 to freeze_to to freeze all except the last two layers groups (learn.freeze() = learn.freeze_to(-1) and learn.unfreeze() = learn.freeze_to(0)).
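
The effect of freeze_to on our four layers groups can be sketched with plain Python slicing. This assumes that freeze_to(n) freezes the groups located before index n and leaves the others trainable; the helper name and the group labels are hypothetical.

```python
def trainable_after_freeze_to(group_names, n):
    """Return the layers groups left trainable after freeze_to(n): those from index n onward."""
    return group_names[n:]

# Order returned by the splitter: 3 groups of decoder blocks, then the embedding group
groups = ["blocks 0-3", "blocks 4-7", "blocks 8-11", "embeddings + ln_f"]

print(trainable_after_freeze_to(groups, -1))  # ['embeddings + ln_f']
print(trainable_after_freeze_to(groups, -2))  # ['blocks 8-11', 'embeddings + ln_f']
print(trainable_after_freeze_to(groups, 0))   # all 4 groups
```

This is why the embedding group was placed last in the splitter: it is the one that has to be trained first when moving from English to Portuguese.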

6.2.2 Freeze all layers but the last 2 layers groups

learn.freeze_to(-2)
learn.summary()

Again, the learn.summary() method gives almost the right numbers. In fact, it counts twice the weights of the wte matrix (vocab embedding matrix) because they are duplicated in the weights of the output linear layer.

The real numbers are:

  • Total params: 163,037,184 - 38,597,376 = 124,439,808 (about 124 million)
  • Total trainable params: 106,334,208 - 38,597,376 = 67,736,832 (about 68 million)
  • Total non-trainable params: 56,702,976 (about 57 million)
learn.freeze_to(-2)
learn.fit_one_cycle(1, slice(1e-3/(2.6**4),1e-3))
train_loss 3.453913
valid_loss 3.301886
accuracy 0.362879
perplexity 27.163816
time 5:38:18

Good! Our model goes on learning. It went

  • from an accuracy of 32.52% to 36.29%
  • from a perplexity of 38.12 to 27.16
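
The slice passed to fit_one_cycle is what implements the differential learning rates. As a sketch, assuming the slice bounds are spread over the 4 layers groups with an even multiplicative step (the helper below is our own re-implementation for illustration, not a call into fastai), the per-group learning rates are:

```python
def even_mults(start, stop, n):
    """n values from start to stop, evenly spaced on a multiplicative (geometric) scale."""
    step = (stop / start) ** (1 / (n - 1))
    return [start * step ** i for i in range(n)]

lr_max = 1e-3
# One learning rate per layers group, from the first decoder blocks to the embedding group
lrs = even_mults(lr_max / 2.6 ** 4, lr_max, 4)
for group, lr in zip(["blocks 0-3", "blocks 4-7", "blocks 8-11", "embeddings + ln_f"], lrs):
    print(f"{group}: {lr:.2e}")
```

The first decoder blocks get the smallest learning rate (lr_max divided by 2.6**4) and the embedding group gets lr_max; frozen groups are simply not updated.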

We can plot the training and validation loss curves.

learn.recorder.plot_loss()
Evolution of training and validation losses during the second fine-tuning epoch of the Portuguese GPT-2

Let’s go on by passing -3 to freeze_to to freeze all except the last three layers groups.

6.2.3 Freeze all layers but the last 3 layers groups

learn.freeze_to(-3)
learn.summary()

The learn.summary() method gives almost the right numbers. In fact, it counts twice the weights of the wte matrix (vocab embedding matrix) because they are duplicated in the weights of the output linear layer.

The real numbers are:

  • Total params: 163,037,184 - 38,597,376 = 124,439,808 (about 124 million)
  • Total trainable params: 134,685,696 - 38,597,376 = 96,088,320 (about 96 million)
  • Total non-trainable params: 28,351,488 (about 28 million)
learn.freeze_to(-3)
learn.fit_one_cycle(1, slice(5e-4/(2.6**4),5e-4))
train_loss 3.333389
valid_loss 3.207390
accuracy 0.374579
perplexity 24.714487
time 6:20:51

Yep! Our model (still) goes on learning: it went

  • from an accuracy of 36.29% to 37.46%
  • from a perplexity of 27.16 to 24.71

We can plot the training and validation loss curves.

learn.recorder.plot_loss()
Evolution of training and validation losses during the third fine-tuning epoch of the Portuguese GPT-2

Let’s finish our work by unfreezing all layers groups, which means all parameters of the Portuguese GPT-2 model.

6.2.4 Unfreeze all layers

learn.unfreeze()
learn.summary()

One more time, the learn.summary() method gives almost the right numbers. In fact, it counts twice the weights of the wte matrix (vocab embedding matrix) because they are duplicated in the weights of the output linear layer.

The real numbers are:

  • Total params: 163,037,184 - 38,597,376 = 124,439,808 (about 124 million)
  • Total trainable params: 163,037,184 - 38,597,376 = 124,439,808 (about 124 million)
  • Total non-trainable params: 0
learn.unfreeze()
learn.fit_one_cycle(2, slice(1e-4/(2.6**4),1e-4))
epoch 0
train_loss 3.288433
valid_loss 3.186721
accuracy 0.377380
perplexity 24.208906
time 6:06:29
epoch 1
train_loss 3.232569
valid_loss 3.167864
accuracy 0.379885
perplexity 23.756687
time 6:16:22

GPUuuuuuuu! Our model (a bit but still) goes on learning: it went

  • from an accuracy of 37.46% to 37.99%
  • from a perplexity of 24.71 to 23.76

We can plot the training and validation loss curves.

learn.recorder.plot_loss()
Training and validation loss evolution during the fourth and fifth epoch

Following the fastai v2 text classification fine tuning strategy and due to our very good results (37.99% accuracy and 23.76 perplexity), we decided to stop fine-tuning the Portuguese GPT-2 at the end of these 5 epochs.

Model sharing and uploading in the Hugging Face model hub

Let’s see now how we can share our Portuguese GPT-2 on the Hugging Face model hub (source: Model sharing and uploading). You will find all the code corresponding to our tokenizer and model in the tutorial notebook.

Thus, our model now has a page on huggingface.co/models 🔥

Anyone can load it from the following code:

from transformers import AutoTokenizer, AutoModelWithLMHead

tokenizer = AutoTokenizer.from_pretrained("pierreguillou/gpt2-small-portuguese")
model = AutoModelWithLMHead.from_pretrained("pierreguillou/gpt2-small-portuguese")

Check our Hugging Face model page to get more information.

Text Generation by our Portuguese GPT-2

Now that we have a GPT-2 in Portuguese, we can use it for different tasks in NLP (Text Generation, Reading Comprehension, Translation, Summarization) as shown in the post “GPT-2 use cases: beyond Text Generation”.

For now, let’s use it to generate new texts, which allows us to check that it works properly and also have a little fun.

Text Generation techniques

At each stage of text generation, GPT-2 provides a vector of 50,257 probabilities (each corresponds to a possible token of the vocabulary, whose size is 50,257). To decide how to choose the output token from these probabilities, there are at least 5 methods: Greedy, Beam Search, Sampling with temperature, Top-k sampling and Top-p (nucleus) sampling.

In this tutorial, we will test only 2 of these text generation methods: Top-k sampling and Top-p (nucleus) sampling.

Note: to get more information on text generation techniques for transformer-based language models, read the article “How to generate text: using different decoding methods for language generation with Transformers” by Patrick von Platen (Hugging Face, 03/18/2020).

(Use case 1) Top-k sampling

Our use case 1 follows the same method used by OpenAI on page 20 of the paper Language Models are Unsupervised Multitask Learners: the Top-k sampling text generation technique with a value of 40.

This text generation method is implemented in the model.generate() function of a Transformers model thanks to the following argument:

  • top_k (int): the number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.
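
To illustrate what top-k filtering does at each generation step, here is a minimal pure-Python sketch on a toy vocabulary of 5 tokens. It is not the Transformers implementation (which works on logits over the whole vocabulary); the function name and the probabilities are made up for the example.

```python
import random

def top_k_sample(probs, k, rng=random):
    """Sample a token id after keeping only the k most probable tokens."""
    # Indices of the k highest probabilities
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # The weights of the kept tokens are renormalized by choices()
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]

probs = [0.50, 0.30, 0.15, 0.04, 0.01]  # toy next-token probabilities
rng = random.Random(0)
samples = {top_k_sample(probs, k=2, rng=rng) for _ in range(1000)}
print(samples)  # only token ids 0 and 1 can appear
```

With k=40 on the real vocabulary, the 50,217 least probable tokens are discarded in the same way before sampling.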

(Use case 2) Top-p (nucleus)

Our use case 2 combines top-p (nucleus) sampling (top_p = 0.95) with top-k sampling (top_k = 50), temperature (temperature = 0.7) and repetition penalty (repetition_penalty = 1.2).

This text generation method is implemented in the model.generate() function of a Transformers model thanks to the following arguments:

  • top_p (float): the cumulative probability threshold for nucleus sampling: only the most probable vocabulary tokens whose probabilities add up to top_p are kept. Must be between 0 and 1. Defaults to 1.
  • top_k (int): the number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity. Defaults to 50.
  • temperature (float): the value used to modulate the next token probabilities. Must be strictly positive. Defaults to 1.0.
  • repetition_penalty (float): the parameter for repetition penalty. Between 1.0 and infinity. 1.0 means no penalty. Defaults to 1.0.
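
The way these arguments reshape the next-token distribution can be sketched in pure Python on a toy vocabulary. This is a simplified illustration, not the Transformers code: the function names and the logits are made up, the top_k step is left out (the toy vocabulary has only 5 tokens), and the repetition penalty is assumed to divide positive logits and multiply negative ones.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def filtered_probs(logits, generated_ids, temperature=0.7, top_p=0.95, repetition_penalty=1.2):
    """Next-token distribution after repetition penalty, temperature and top-p filtering."""
    logits = list(logits)
    # Repetition penalty: make already generated tokens less likely
    for i in set(generated_ids):
        logits[i] = logits[i] / repetition_penalty if logits[i] > 0 else logits[i] * repetition_penalty
    # A temperature below 1 sharpens the distribution, above 1 it flattens it
    probs = softmax([x / temperature for x in logits])
    # Top-p: keep the smallest set of tokens whose cumulative probability reaches top_p
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}

# Toy logits for 5 tokens; token 0 has already been generated
dist = filtered_probs([2.0, 1.0, 0.5, -1.0, -3.0], generated_ids=[0])
print(sorted(dist))  # [0, 1, 2]
```

In this toy case the two least probable tokens fall outside the nucleus, and the next token is sampled among the three remaining ones.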

Text n°1 | Famous OpenAI generated text about unicorns

At the time of publication of GPT-2 in the article “Better Language Models and Their Implications” (02/14/2019), what the media retained from its different possibilities in NLP was text generation, because of the now famous text generated on unicorns from this small paragraph: “In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.”

Famous text on unicorns generated by English GPT-2 from OpenAI (sources: sample 1 and page 20 from “Language Models are Unsupervised Multitask Learners”)

Get translated famous unicorn text in Portuguese

By using the MarianMT English to Portuguese translator that is available in the Transformers library of Hugging Face, we obtained the Portuguese version of this text: “Em um achado chocante, o cientista descobriu um rebanho de unicórnios vivendo em um vale remoto, anteriormente inexplorado, nas Montanhas dos Andes. Ainda mais surpreendente para os pesquisadores foi o fato de que os unicórnios falavam inglês perfeito.”

from transformers import MarianMTModel, MarianTokenizer

# The >>pt_BR<< prefix asks the multilingual model for Brazilian Portuguese
src_text = [
    '>>pt_BR<< In a shocking finding, scientist discovered a herd of unicorns living in a remote, previously unexplored valley, in the Andes Mountains. Even more surprising to the researchers was the fact that the unicorns spoke perfect English.',
]
model_name = 'Helsinki-NLP/opus-mt-en-ROMANCE'
tokenizer_en_pt = MarianTokenizer.from_pretrained(model_name)
print(tokenizer_en_pt.supported_language_codes)
model_en_pt = MarianMTModel.from_pretrained(model_name)
# prepare_translation_batch is the tokenizer method of the Transformers version used at the time of writing
translated = model_en_pt.generate(**tokenizer_en_pt.prepare_translation_batch(src_text))
tgt_text = [tokenizer_en_pt.decode(t, skip_special_tokens=True) for t in translated]

Get generated text

Use case 1 (Top-k sampling)

The code is:

# Set top_k = 40 and num_return_sequences = 3
sample_outputs = model_pt.generate(input_ids, pad_token_id=50256,
                                   do_sample=True,
                                   max_length=max_length,
                                   min_length=max_length,
                                   top_k=40,
                                   num_return_sequences=3)
for i, sample_output in enumerate(sample_outputs):
    print(">> Generated text {}\n\n{}".format(i+1, tokenizer_pt.decode(sample_output.tolist())))
    print('\n---')

The best text among the 3 generated is:

Num achado chocante, o cientista descobriu uma manada de unicórnios vivendo num vale remoto, anteriormente inexplorado, nas Montanhas dos Andes. Ainda mais surpreendente para os pesquisadores foi o fato de que os unicórnios falavam inglês perfeito. "Não é mais estranho que a nossa forma tivesse o inglês com dois de suas asas como se o macho fosse inglês — o que é interessante. Mas a sua natureza inata seria estranha para o inglês", acredita eles.

Em 2015, cientistas realizaram uma nova análise sobre as formas dos unicórnios. De acordo com especialistas na área, os membros superiores do grupo foram provavelmente derivados de outra espécie de escorpião — uma espécie com características semelhantes. Uma nova equipe de cientistas calculou que uma fêmea unicornada da Eurásia seria originalmente uma humana. "Isto significa que o ancestral do unicórnio, um híbrido de um esquilo e um escorpião macho não nasceu.

Um estudo recente estimou que cerca de 12% do corpo humano é composto por membros de qualquer um dos grupos mais diverso de animais extintos, incluindo o ser humano e o unicórnio-do-sul. "A análise dos dados mostra que a maioria dos membros do gênero é composta por um exito e um exito macho que compartilham uma única espécie de corpo. Em alguns casos, estes membros compartilham um mesmo conjunto (a linhagem) de partes em dois espécimes." O estudo indica que os membros de "P. rubi" são semelhantes em aparência e morfologia aos membros humanos modernos, como as fêmeas modernas e machos robustos. "Como é evidente com os membros de "P. rubi", os ancestrais e o ancestral eram similares na forma e na composição das semelhanças em um organismo."

O DNA do "P. rubi," chamado por sua forma em inglês de "sonoroplasto", revela que o "sonoroplasto" inclui três genes de alto nível e quatro genes relativamente reduzidos (e ausentes) e um gene de baixo nível (e ausentes) com uma concentração de cloroplasto em cada núcleo. A "sonoroplasto" se assemelha à "P. rubi" em características morfológicas e comportamentais, embora as diferenças na morfologia sejam menores. "A espécie "P. rubi" apresenta cinco pares de cromossomos separados (com 6 pares se aproximando e 8 pares se afastando) e um "sonoroplasto de base" (com 12 pares se aproximando e 15 pares se afastando), sugerindo que o membro tenha uma composição semelhante ao ancestral "P. rubi".

Use case 2 (Top-p nucleus sampling)

The code is:

# Set top_p = 0.95, top_k = 50, temperature = 0.7, repetition_penalty = 1.2 and num_return_sequences = 3
sample_outputs = model_pt.generate(input_ids, pad_token_id=50256,
                                   do_sample=True,
                                   max_length=max_length,
                                   min_length=max_length,
                                   repetition_penalty=1.2,
                                   temperature=0.7,
                                   top_k=50,
                                   top_p=0.95,
                                   num_return_sequences=3)
for i, sample_output in enumerate(sample_outputs):
    print(">> Generated text {}\n\n{}".format(i+1, tokenizer_pt.decode(sample_output.tolist())))
    print('\n---')

The best text among the 3 generated is:

Num achado chocante, o cientista descobriu uma manada de unicórnios vivendo num vale remoto, anteriormente inexplorado, nas Montanhas dos Andes. Ainda mais surpreendente para os pesquisadores foi o fato de que os unicórnios falavam inglês perfeito. Eles não sabiam onde exatamente eram falantes nativos do idioma, e acreditaram que eles simplesmente migraram das terras altas da região de Mendoza ao norte como consequência do declínio populacional que ocorreu na Cordilheira das Cobras.

Em 2004, o Departamento de Antropologia da Universidade do Colorado anunciou que havia encontrado uma fêmea no vale do rio Orinoco na Bolívia, mas essa fêmea foi morta durante a investigação. No entanto, no início de 2006, as autoridades locais anunciaram que havia identificado uma fêmea encontrada em uma área próxima à Cordilheira dos Andes, no Vale do Cauca. A equipe de pesquisadores relatou que esta fêmea era chamada de "El Maria" ou "El Maria".

O estudo revelou que o grupo de unicórnios habitava um ecossistema bastante diverso, com espécies endêmicas incluindo espécies como as tiláceas gigantescas (que são encontradas principalmente nos países subdesenvolvidos) e as quelupus ("Erica azoricae").

Um dos principais objetivos do estudo da espécie é determinar se os europeus teriam colonizado a região entre a década de 1940 e 1960 e se estes últimos grupos étnicos sobreviveram até hoje. Os cientistas acreditam que as populações desses grupos poderiam ter sido muito maiores antes disso; por exemplo, a teoria sugere que a população europeia provavelmente teria introduzido os humanos primitivos na América Central depois que os espanhóis invadiram a região, embora isso seja controverso.

O gênero "El Maria" tem um ancestral comum, os "Looney-do-the-Bone", um pequeno grupo de "Looney-da-Daíndia" encontrados apenas no leste dos Estados Unidos, Canadá e México. O gênero possui parentesco próximo ao gênero "Lontrapyrus", também conhecido como lontras negras. Acredita-se que esses indivíduos tenham migrado para o leste dos Andes, atravessando regiões montanhosas do sul de América Central e América Central.

Os membros desta família são geralmente confundidos com os lontras brancos.

As fêmeas têm cerca de seis centímetros de comprimento, pesando de 9 quilogramas e medindo 11 cm de largura. A cabeça é branca, com manchas escuras pretas escuras sobre seus flancos. As patas posteriores podem ser amarelas, enquanto sua cauda pode estar preta ou branca, dependendo da cor utilizada na identificação. As costas apresentam quatro dedos dorsais bem desenvolvidas.

Text n°2 | Recent text on the coronavirus disease (Covid-19)

Among the links returned by Google News for the keyword covid-19, we selected the article “Vacina contra coronavírus feita pela Rússia entra em ultima fase de testes” (UOL, 07/13/2020) and copied its first paragraph as input for our GPorTuguese-2 model:

A Rússia está mais perto de se tornar o primeiro país a iniciar a distribuição de uma vacina contra o coronavírus para a população. O país anunciou hoje que concluiu parte dos testes clínicos necessários para comprovar a eficácia da imunização desenvolvida por iniciativa do governo russo. A expectativa é de que a distribuição comece já em agosto.

Get generated text

Use case 1 (Top-k sampling)
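
As a reminder of what Top-k sampling does, the following toy sketch keeps only the k most probable next tokens, renormalizes their probabilities and samples among them. The probabilities, the value k=2 and the helper name `top_k_filter` are assumptions chosen for illustration; the actual generation relies on the `top_k` argument of the model's `generate` method.

```python
import random

def top_k_filter(probs, k):
    """Keep the k most probable tokens, zero out the rest, renormalize."""
    keep = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

# Hypothetical next-token distribution over a 4-token vocabulary.
toy_probs = [0.5, 0.3, 0.15, 0.05]
filtered = top_k_filter(toy_probs, k=2)

# Sample the next token id from the filtered distribution (seeded for repeatability).
rng = random.Random(0)
next_token = rng.choices(range(len(filtered)), weights=filtered, k=1)[0]

print(filtered)
print(next_token)
```

With k fixed, the number of candidate tokens is the same whether the model is very confident or very unsure, which is the limitation that Top-p sampling tries to address.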

The best text among the 3 generated is:

A Rússia está mais perto de se tornar o primeiro país a iniciar a distribuição de uma vacina contra o coronavírus para a população. O país anunciou hoje que concluiu parte dos testes clínicos necessários para comprovar a eficácia da imunização desenvolvida por iniciativa do governo russo. A expectativa é de que a distribuição comece já em agosto.

Entre os primeiros casos confirmados de COVID-19 na Rússia estava um homem russo com idade entre 50 e 49 anos, segundo o jornal " Moscow" e o primeiro caso em 28 de fevereiro, a primeira no Hospital Pulkai.

No entanto, os primeiros casos foram mais sérios devido sua "maturidade sem gravidade" e a falta de uma pessoa estar disponível para testes de suas condições de vida. Entre os pacientes que foram considerados estão funcionários de um hospital ou enfermeiros na cidade de Moscou (ver abaixo) ou médicos. Depois de serem testados negativos após um teste positivo, o paciente se recupera completamente.

Os russos também anunciaram que serão realizados testes de coronavírus de outros países, como a França, que inicialmente acreditava que os vírus da gripe tinham sido transmitido pela Europa (a expectativa é de 20 a 50 casos por dia). A situação foi resolvida em 24 de fevereiro, quando o Ministério da Saúde confirmou sua conclusão de que a COVID-19 é transmitida de via aérea.

Até ao dia do seu primeiro caso, a Rússia tinha o menor número de funcionários e médicos com doença grave antes de o vírus ter se tornado um vírus no país. Os números de funcionários com doença grave não chegaram a ser confirmados, segundo o Ministério da Saúde, devido a sua falta de apoio.

O vírus que está em curso no país é transmitido pela primeira vez nos Estados Unidos, onde foi isolado em 14 de fevereiro. Em Portugal o Ministério dos Negócios Estrangeiros declarou oficialmente em 11 de fevereiro que a COVID-19 está presente no país, mas não anunciou nenhum impacto na prática. No entanto, o Ministério do Trabalho declarou no dia seguinte que o coronavírus existe "em todo os países que não têm regulamentação" e que nenhum "aedes", que era identificado em 11 de fevereiro, tinha entrado na corrente sanguínea.

Em 9 de março, autoridades russas confirmaram que o paciente não está hospitalizado e que está em quarentena no Hospital Pulkai.

Mais dois casos de COVID-19 foram anunciados em 9 de março, mas foram considerados casos de "propaganda" e não de risco.

Em 9 de maio, autoridades russas confirmaram os confirmados em todo país.

Use case 2 (Top-p nucleus sampling)
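
The idea behind Top-p (nucleus) sampling can be sketched on a toy distribution: instead of a fixed number of tokens, we keep the smallest set of most probable tokens whose cumulative probability reaches p. The helper name `top_p_filter` and the two example distributions are assumptions made for this illustration, and the library's own implementation differs in its details; p=0.95 matches the `top_p` value used in this article.

```python
def top_p_filter(probs, p):
    """Keep the smallest set of top tokens whose cumulative probability reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = [], 0.0
    for i in order:
        keep.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

# Hypothetical distributions: one where the model is confident, one where it is not.
peaked = [0.97, 0.01, 0.01, 0.01]
flat = [0.4, 0.3, 0.2, 0.1]

peaked_filtered = top_p_filter(peaked, p=0.95)
flat_filtered = top_p_filter(flat, p=0.95)

print(sum(1 for x in peaked_filtered if x > 0), "token(s) kept for the peaked distribution")
print(sum(1 for x in flat_filtered if x > 0), "token(s) kept for the flat distribution")
```

The size of the candidate set therefore adapts to the model's confidence: a single token when one continuation dominates, many tokens when several continuations are plausible.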

The best text among the 3 generated is:

A Rússia está mais perto de se tornar o primeiro país a iniciar a distribuição de uma vacina contra o coronavírus para a população. O país anunciou hoje que concluiu parte dos testes clínicos necessários para comprovar a eficácia da imunização desenvolvida por iniciativa do governo russo. A expectativa é de que a distribuição comece já em agosto.

A Organização Mundial de Saúde (OMS) estima que, no final de dezembro de 2015, havia pelo menos 50 milhões de pessoas infectadas com o vírus na Rússia e Ucrânia. Estimativas semelhantes foram feitas pela OMS sobre os casos registrados desde a década passada. Em março de 2018, o Ministério da Saúde confirmou a descoberta de um novo coronavírus e recomendou aos cidadãos russos manter a disposição para evitar viagens ao exterior após o início das aulas médicas ou exames médicos.

Em novembro de 2016, o Ministro da Saúde russo confirmou que todos os indivíduos infectados poderiam ser testados em março; assim como seus familiares e amigos, eles podem fazer teste clínico em uma unidade cirúrgica do hospital da cidade em que são colocados, caso necessário. Também foi anunciado no mesmo dia que o Ministério do Trabalho revelou que as autoridades russas estão trabalhando em conjunto visando reduzir o número de mortes causadas pelas epidemias. A agência informou que a Agência Nacional de Vigilância Sanitária Russa (Anvisa) começou a monitorar a pandemia através de máscaras faciais nos hospitais.

Em abril de 2019, o Ministério da Saúde divulgou que 582 mil pessoas haviam sido infetadas com o vírus no país entre janeiro de 2019 e maio de 2020. Cerca de 370 mil desses pacientes estariam diretamente relacionados à doença.

A Rússia também tem planos promissores para produzir vacinas que sejam eficazes contra o coronavírus, incluindo a vacina anti-SIDA e antivirais (ver Lista Vermelha da OMS).

Em 1º de julho de 2017, a Organização Mundial de Saúde lançou uma nota oficial alertando que "um grande aumento pode vir da necessidade de medidas preventivas necessárias" antes do início das aulas médicas em escolas públicas nas cidades ucranianas.

Em outubro de 2017, a Secretaria Municipal de Saúde ucraniana publicou uma nota oficial informando que os profissionais responsáveis pela coordenação da vacinação deveriam estar preparados, bem como suas famílias e amigos durante a realização de exames adicionais para determinar sua saúde mental.

Em 30 de junho de 2019, o Ministério da Saúde lançou um comunicado afirmando que três grupos escolares teriam dificuldades de administrar adequadamente o vacina contra o coronavírus na Rússia.

Conclusion

We were the first to be pleasantly surprised by how efficiently an English pre-trained transformer-based language model such as GPT-2 small can be fine-tuned to Portuguese.

In about one day, using a single GPU and a little over 1 GB of Portuguese text, we obtained GPorTuguese-2, a model able to generate contextual Portuguese text at a level comparable to that of the GPT-2 used by OpenAI in 2019.

Happy.

The next step would be to apply our fine-tuning method to more recent NLP models such as GPT-3, BART, T5 or Reformer. Shall we do it?

Annex | Other articles about fine-tuning GPT-2 to another language

About the author: Pierre Guillou is an AI consultant in Brazil and France, a Deep Learning and NLP researcher in the AI Lab (UnB), and a professor of Artificial Intelligence (UnB). Please contact him via his LinkedIn profile.
