🌓 From TensorFlow to PyTorch

Thomas Wolf
Aug 9, 2019 · 7 min read

Friends and users of our open-source tools are often surprised by how fast 🚀 we reimplement the latest SOTA pre-trained TensorFlow models to make them accessible to everyone in our libraries like PyTorch-Transformers 👾 or PyTorch-pretrained-BigGAN 🦋

In this post, you’ll learn the main recipe to convert a pretrained TensorFlow model into a pretrained PyTorch model, in just a few hours.

We’ll take the example of a simple architecture like OpenAI GPT-2 🦄

Doing such a conversion assumes a good familiarity with both TensorFlow and PyTorch, but it’s also one of the best ways to get to know both frameworks better!

Looking at the scope structure 🔎

TensorFlow checkpoints are usually composed of three files named XXX.ckpt.data-YYY, XXX.ckpt.index and XXX.ckpt.meta.

A trained NLP model should also come with a vocabulary to associate the tokens with the embedding indices (here encoder.json and vocab.bpe). We won’t go into too much detail about the vocabulary and tokenizer here since you can usually reuse their original Python code directly with minor modifications.

First, we can have a look at the hyper-parameters file: hparams.json. It contains a few hyper-parameters like the number of layers/heads and so on:
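
For the small (117M-parameter) GPT-2 model, the content of hparams.json looks like this:

```json
{
  "n_vocab": 50257,
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_layer": 12
}
```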

We can reuse this JSON file in a configuration class for our model.
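
A minimal sketch of such a configuration class (the real one in pytorch-transformers has a few more options) could look like this:

```python
import json


class GPT2Config:
    """Minimal configuration object mirroring the fields of hparams.json."""

    def __init__(self, n_vocab=50257, n_ctx=1024, n_embd=768, n_head=12, n_layer=12):
        self.n_vocab = n_vocab   # size of the BPE vocabulary
        self.n_ctx = n_ctx       # maximum sequence length (number of positions)
        self.n_embd = n_embd     # dimension of the hidden states
        self.n_head = n_head     # number of attention heads
        self.n_layer = n_layer   # number of Transformer blocks

    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r", encoding="utf-8") as f:
            return cls(**json.load(f))
```

You can then build the configuration directly from the downloaded file, e.g. `config = GPT2Config.from_json_file("models/117M/hparams.json")`.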

Now, let’s have a look at the structure of the model. From this point on, you’ll need to have TensorFlow installed on your computer (the CPU version is fine). Once TensorFlow is set up, open a Python interpreter and load the checkpoint to inspect the saved variables:
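
For example (the checkpoint path is illustrative and depends on where you downloaded the model):

```python
import tensorflow as tf

ckpt_path = "models/117M/model.ckpt"             # path to the downloaded GPT-2 checkpoint
variables = tf.train.list_variables(ckpt_path)   # list of (name, shape) pairs
for name, shape in variables:
    print(name, shape)
```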

The result is a (long) list of all the variables stored in the checkpoint with their names and shapes:

Variables are stored as NumPy arrays that you can load with tf.train.load_variable(name).
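
For example, to load the token embedding matrix (again, the path is illustrative):

```python
# Load one variable by name as a NumPy array.
wte = tf.train.load_variable(ckpt_path, "model/wte")
print(wte.shape)  # (50257, 768) for the 117M model: vocabulary size x hidden size
```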

Now, what we are particularly interested in here are the path-like names of the variables, like model/h0/ln_1/b, which reflect the organization of TensorFlow variables in scopes.

Here is our first secret:

To build our PyTorch model as fast as possible, we will reuse exactly the same organization: for each sub-scope in the TensorFlow model, we’ll create a sub-class under the same name in PyTorch.

This will let us load weights easily by jointly iterating on scopes & classes.

As you can see, GPT-2 has three modules at the root of the model (at the end of the list): model/wte, model/wpe and model/ln_f, and the rest of the model is composed of a series of identical modules hXX, each comprising a self-attention sub-module attn, a feed-forward module mlp and two layer-normalization modules ln_1 and ln_2.

Now that we know how the model is organized, let’s build our PyTorch model with a hierarchy that reproduces this organization of scopes.

Building the PyTorch model skeleton 👩‍🎨
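
Here is a minimal sketch of such a skeleton, simplified from the GPT2Model class in pytorch-transformers (it relies on the Block and LayerNorm sub-modules sketched further below):

```python
import torch
import torch.nn as nn


class GPT2Model(nn.Module):
    """Skeleton of the model: attribute names mirror the TensorFlow scopes."""

    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.n_vocab, config.n_embd)  # token embeddings    -> model/wte
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)    # position embeddings -> model/wpe
        self.h = nn.ModuleList(
            [Block(config) for _ in range(config.n_layer)]      # Transformer blocks  -> model/h0 ... model/hN
        )
        self.ln_f = LayerNorm(config.n_embd)                    # final layer norm    -> model/ln_f
```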

As you can see, we’ve given our main sub-modules names (wte, wpe, h, ln_f) that are identical to the first-level scopes of the variables we saw in the TensorFlow checkpoint.

We can also write the code for our forward pass by converting the code for the main model from TensorFlow operations to PyTorch operations:
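
A simplified sketch of this forward pass, using the attribute names of the skeleton above:

```python
class GPT2Model(nn.Module):
    # ... __init__ as in the skeleton above ...

    def forward(self, input_ids):
        # Position indices 0 .. seq_len-1, broadcast over the batch dimension.
        position_ids = torch.arange(input_ids.size(-1), dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        # Sum of token and position embeddings, as in the TensorFlow model.
        hidden_states = self.wte(input_ids) + self.wpe(position_ids)

        # Run the stack of identical Transformer blocks.
        for block in self.h:
            hidden_states = block(hidden_states)

        # Final layer normalization.
        return self.ln_f(hidden_states)
```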

Here I’ve removed the hidden-state caching logic (past) to keep things simple. It’s an option used to speed up inference.

Now we dive deeper into the hierarchy, continuing to build our PyTorch model by adapting the rest of the TensorFlow code. Here is another example, comparing the TensorFlow code for a “Block” module:
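
The TensorFlow code (reproduced approximately from OpenAI’s gpt-2 repository, so treat the details as indicative) looks like this:

```python
def block(x, scope, *, past, hparams):
    with tf.variable_scope(scope):
        nx = x.shape[-1].value
        # Self-attention sub-layer, pre-normalized with a residual connection.
        a, present = attn(norm(x, 'ln_1'), 'attn', nx, past=past, hparams=hparams)
        x = x + a
        # Feed-forward (MLP) sub-layer, also pre-normalized with a residual connection.
        m = mlp(norm(x, 'ln_2'), 'mlp', nx * 4, hparams=hparams)
        x = x + m
        return x, present
```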

To the PyTorch equivalent nn.Module class:
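
A sketch of the PyTorch counterpart (the Attention and MLP sub-modules would be written in the same way and are not shown here):

```python
class Block(nn.Module):
    """PyTorch counterpart of the TensorFlow block() function above.
    Attribute names (ln_1, attn, ln_2, mlp) match the TF scope names."""

    def __init__(self, config):
        super().__init__()
        nx = config.n_embd
        self.ln_1 = LayerNorm(nx)
        self.attn = Attention(nx, config)  # self-attention sub-module (not shown)
        self.ln_2 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, config)     # feed-forward sub-module (not shown)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # pre-norm self-attention + residual
        x = x + self.mlp(self.ln_2(x))   # pre-norm feed-forward + residual
        return x
```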

Here again, the names of the class attributes containing the sub-modules (ln_1, ln_2, attn, mlp) are identical to the associated TensorFlow scope names that we saw in the checkpoint list above. Doing this ensures that the PyTorch hierarchical attribute structure will be identical to the TensorFlow scope structure.

Beware of the details — section I 🕵️

The computation flow

It’s a good opportunity to dive into the internals of both frameworks to see how each operation is implemented under the hood. One example: TensorFlow’s and PyTorch’s layer normalizations are slightly different from each other (go check them out!), so I usually reimplement layer normalization from scratch in PyTorch.
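
Here is a from-scratch layer normalization in the spirit of the one used in pytorch-transformers (the parameter names weight and bias will matter later, when we map the TF variables g and b onto them):

```python
class LayerNorm(nn.Module):
    """Layer normalization reimplemented from scratch so that we control
    exactly where epsilon enters the computation."""

    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))  # "g" (gain) in the TF checkpoint
        self.bias = nn.Parameter(torch.zeros(hidden_size))   # "b" (bias) in the TF checkpoint
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = (x - mean).pow(2).mean(-1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * x + self.bias
```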

The initialization and defaults

The two frameworks also don’t share the same default weight initializations or the same default hyper-parameter values (for instance the epsilon of the normalization layers), so check every default carefully when you port a module.

Loading the weights 🏋️

Having the same model organization makes loading the weights very easy:

We just jointly iterate on both the path-like names of TensorFlow variables & our PyTorch model attributes.

A commented loading function for GPT-2 looks like this:
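
Here is a sketch of such a loading function, close to the load_tf_weights_in_gpt2 function shipped in pytorch-transformers:

```python
import os
import re

import tensorflow as tf
import torch


def load_tf_weights_in_gpt2(model, gpt2_checkpoint_path):
    """Load a TensorFlow GPT-2 checkpoint into the PyTorch model by walking
    the scope-like variable names and the matching PyTorch attributes."""
    tf_path = os.path.abspath(gpt2_checkpoint_path)
    # Read every variable of the checkpoint as a NumPy array.
    names, arrays = [], []
    for name, _ in tf.train.list_variables(tf_path):
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array.squeeze())  # drop singleton dims (e.g. the leading 1 of the conv1d kernels)

    for name, array in zip(names, arrays):
        name = name[len("model/"):]  # strip the root scope
        pointer = model              # walk down the PyTorch attributes in parallel
        for m_name in name.split("/"):
            if re.fullmatch(r"[A-Za-z]+\d+", m_name):   # 'h0', 'h1', ... -> ('h', 0), ('h', 1), ...
                scope_names = re.split(r"(\d+)", m_name)
            else:
                scope_names = [m_name]
            if scope_names[0] in ("w", "g"):             # TF 'w'/'g' -> PyTorch 'weight'
                pointer = getattr(pointer, "weight")
            elif scope_names[0] == "b":                  # TF 'b' -> PyTorch 'bias'
                pointer = getattr(pointer, "bias")
            elif scope_names[0] in ("wte", "wpe"):       # embeddings: point to their weight matrix
                pointer = getattr(pointer, scope_names[0]).weight
            else:                                        # regular sub-module (ln_1, attn, mlp, ...)
                pointer = getattr(pointer, scope_names[0])
            if len(scope_names) >= 2:                    # index into the ModuleList (h[0], h[1], ...)
                pointer = pointer[int(scope_names[1])]
        assert pointer.shape == array.shape, (name, pointer.shape, array.shape)
        pointer.data = torch.from_numpy(array)
    return model
```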

Let’s talk about a few things to keep in mind at this stage 👇

Beware of the details — section II 🕵️

Transposing tensors from TensorFlow to PyTorch

Some weights have to be transposed when going from TensorFlow to PyTorch. The main case where this happens in practice is dense layers like tf.layers.dense, whose kernel is the transpose of PyTorch’s nn.Linear weight.
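
A minimal illustration of handling the transposition (the shapes and arrays here are illustrative stand-ins for values loaded from the checkpoint):

```python
import numpy as np
import torch
import torch.nn as nn

# A tf.layers.dense kernel has shape (in_features, out_features), while
# torch.nn.Linear.weight has shape (out_features, in_features): transpose on copy.
tf_kernel = np.random.randn(768, 3072).astype(np.float32)  # stand-in for a kernel loaded from the checkpoint
tf_bias = np.zeros(3072, dtype=np.float32)                  # stand-in for the matching bias

linear = nn.Linear(768, 3072)
with torch.no_grad():
    linear.weight.copy_(torch.from_numpy(tf_kernel).t())  # note the .t()
    linear.bias.copy_(torch.from_numpy(tf_bias))
```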

This transposition issue can be especially tricky to detect for square matrices, which brings us to our last section 👇

The final step —️ comparing the models 👭

Comparing hidden-states 🎼

Now it’s time for a careful comparison of both models!

The way I usually do it is to start from a script running the TensorFlow model provided by the authors of the original implementation, and then:

  • modify the TensorFlow model to output the hidden-states at regular locations along the depth of the model,
  • modify our PyTorch model to output the hidden-states at the same regular locations along the depth of the model,
  • load the PyTorch model in parallel with the TensorFlow model and run them on the same inputs,
  • compare their behaviors during a forward pass to detect where an error may have been made.

You should take care to deactivate the dropout modules and all other nondeterministic modules so that the outputs of the two models are directly comparable.
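
On the PyTorch side this boils down to calling eval() and running under torch.no_grad() (the model and input names here are illustrative); on the TensorFlow side, make sure any training/dropout flags are turned off:

```python
pt_model.eval()        # switch dropout (and other train-time behaviours) to inference mode
with torch.no_grad():  # no gradients needed for the comparison
    pt_hidden_states = pt_model(input_ids)
```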

If your script is a fine-tuning script and your model contains weights which are newly initialized, you should take care to fully initialize the PyTorch model from the newly initialized TensorFlow model for a fair comparison. Here is an example of this process during the reimplementation of XLNet in pytorch-transformers, where the new TensorFlow model is saved and loaded in PyTorch.

I usually compare the max absolute difference between the hidden-states after each layer of the models on a few real-life inputs:
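
A sketch of this check, assuming each model has been modified to return the list of per-layer hidden-states on the same inputs (tf_hidden_states holds NumPy arrays, pt_hidden_states holds PyTorch tensors; both names are illustrative):

```python
import numpy as np

for layer, (tf_h, pt_h) in enumerate(zip(tf_hidden_states, pt_hidden_states)):
    max_abs_diff = np.max(np.abs(tf_h - pt_h.detach().numpy()))
    print(f"layer {layer:2d}: max abs diff = {max_abs_diff:.2e}")
```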

Comparing on a down-stream task 🚣

This comparison can be quite long, as you will need to reproduce the pre-processing, optimization and post-processing of the original authors’ work.

In our experience, a discrepancy at this stage, in pretty much every case, doesn’t come from a difference inside the models but from a discrepancy in the way the inputs are prepared, in the optimization parameters (one of the most often overlooked being the batch size) or in the post-processing and evaluation metrics.

That’s all folks 👭

This method has a few limits:

  • the model may end up having a deeper hierarchy than necessary. In this case, you can rewrite the model to reduce the number of classes and use a mapping between the TensorFlow variables and the PyTorch attributes 🗺
  • the model is sometimes implemented with operations that are fast in TensorFlow or on TPUs (e.g. multiplication with one-hot matrices) but may be suboptimal in PyTorch. Here again, some rewriting and conversion afterward can help speed up the resulting model in some cases 🏎
  • you need access to the TensorFlow code for the conversion. It’s possible to convert a TensorFlow model without access to the code, e.g. a model only available on TensorFlow Hub, but it’s a far more difficult process. In PyTorch-pretrained-BigGAN we did that by inspecting the raw computation graph and guessing the high-level operations involved 🙃

👾 For detailed code examples of this process, you can have a look at the various models implemented in PyTorch-Transformers.

… and if you feel like adding one of your own, we will probably be more than happy to welcome a Pull Request on the repository! Just ping us before to be sure we are not already working on it 😉
