🌓 From TensorFlow to PyTorch

Thomas Wolf
Aug 9, 2019

Friends and users of our open-source tools are often surprised by how fast 🚀 we reimplement the latest SOTA pre-trained TensorFlow models to make them accessible to everyone in our libraries, such as PyTorch-Transformers 👾 or PyTorch-pretrained-BigGAN 🦋

In this post, you’ll learn the main recipe to convert a pretrained TensorFlow model into a pretrained PyTorch model in just a few hours.

We’ll use a simple architecture as our example: OpenAI GPT-2 🦄

Doing such a conversion assumes good familiarity with both TensorFlow and PyTorch, but it’s also one of the best ways to get to know both frameworks better!

Looking at the scope structure 🔎

The first step is to retrieve the TensorFlow code and a pretrained checkpoint. Let’s get them from the official OpenAI GPT-2 repository.

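TensorFlow checkpoints are usually composed of three files named XXX.ckpt.data-YYY, XXX.ckpt.index and XXX.ckpt.meta. The data file holds the variable values, the index file maps each variable name to its location in the data file(s), and the meta file holds the serialized graph.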

A trained NLP model should also be provided with a vocabulary that maps tokens to embedding indices (here encoder.json and vocab.bpe). We won’t go into much detail about the vocabulary and tokenizer here, since you can usually reuse their original Python code directly with minor modifications.

First, let’s have a look at the hyper-parameters file, hparams.json. It contains a few hyper-parameters such as the number of layers and attention heads.
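On the PyTorch side it helps to keep these values in one small config object. The following is a minimal sketch, assuming hparams.json holds exactly the keys n_vocab, n_ctx, n_embd, n_head and n_layer (the keys we know of in the GPT-2 release); GPT2Config and from_json are hypothetical names, not part of any library.

```python
import json
from dataclasses import dataclass


@dataclass
class GPT2Config:
    """Hypothetical container mirroring the keys of GPT-2's hparams.json."""
    n_vocab: int   # vocabulary size (rows of model/wte)
    n_ctx: int     # maximum sequence length (rows of model/wpe)
    n_embd: int    # hidden size
    n_head: int    # attention heads per block
    n_layer: int   # number of blocks, i.e. scopes h0 ... h{n_layer - 1}

    def __post_init__(self):
        if self.n_embd % self.n_head != 0:
            raise ValueError("n_embd must be divisible by n_head")

    @property
    def head_dim(self):
        return self.n_embd // self.n_head

    @classmethod
    def from_json(cls, text):
        # An unexpected key raises a TypeError: an early warning that the file
        # does not look like what the PyTorch code expects.
        return cls(**json.loads(text))
```

Usage would be along the lines of GPT2Config.from_json(open(path).read()), where path points at the downloaded hparams.json.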

Now, let’s have a look at the structure of the model. From now on, you’ll need TensorFlow installed on your computer (the CPU version is enough). Once TensorFlow is set up, open a Python interpreter and load the checkpoint to inspect the saved variables.

The result is a (long) list of all the variables stored in the checkpoint, with their names and shapes.

Each variable can be loaded as a NumPy array with tf.train.load_variable(ckpt_dir_or_file, name).

Now, what we are particularly interested in here are the path-like names of the variables, like model/h0/ln_1/b, which reflect the organization of TensorFlow variables in scopes.

Here is our first secret:

To build our PyTorch model as fast as possible, we will reuse exactly the same organization: for each sub-scope in the TensorFlow model, we’ll create a sub-module with the same name in PyTorch.

This will let us load the weights easily by jointly iterating over scopes and classes.

As you can see, GPT-2 has three modules at the root of the model (at the end of the list): model/wte, model/wpe and model/ln_f. The rest of the model is composed of a series of identical modules hXX, each comprising a self-attention sub-module attn, a feed-forward module mlp and two layer-normalization modules ln_1 and ln_2.
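To see this organization at a glance, you can fold the flat list of names into a tree. The sketch below is plain Python. In practice the names would come from tf.train.list_variables(ckpt_dir_or_file), which returns (name, shape) pairs. Here we hard-code a few GPT-2-style names instead; the leaf names (c_attn, c_fc, w, g, ...) are given as we recall them from the released checkpoint and are only illustrative. scope_tree and print_tree are hypothetical helpers.

```python
def scope_tree(names):
    """Nest path-like TF variable names ('model/h0/ln_1/b') into dicts of sub-scopes."""
    tree = {}
    for name in names:
        node = tree
        for part in name.split("/"):
            node = node.setdefault(part, {})   # leaves end up as empty dicts
    return tree


def print_tree(tree, indent=0):
    for key in sorted(tree):
        print("  " * indent + key)
        print_tree(tree[key], indent + 1)


# Illustrative subset; a real checkpoint lists every variable of every block.
names = [
    "model/h0/attn/c_attn/b", "model/h0/attn/c_attn/w",
    "model/h0/attn/c_proj/b", "model/h0/attn/c_proj/w",
    "model/h0/ln_1/b", "model/h0/ln_1/g",
    "model/h0/ln_2/b", "model/h0/ln_2/g",
    "model/h0/mlp/c_fc/b", "model/h0/mlp/c_fc/w",
    "model/h0/mlp/c_proj/b", "model/h0/mlp/c_proj/w",
    "model/h1/ln_1/b", "model/h1/ln_1/g",
    "model/ln_f/b", "model/ln_f/g", "model/wpe", "model/wte",
]
print_tree(scope_tree(names))
```

Each first-level key under model then becomes an attribute of the main PyTorch class, and each hXX subtree becomes one block.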

Now that we know how the model is organized, let’s build our PyTorch model with a hierarchy that reproduces this organization of scopes.

Building the PyTorch model skeleton 👩‍🎨

It’s time to look at the TensorFlow code itself. We’ll start with the code for the main model and reproduce its general organization in our main PyTorch model class.

As you can see, we’ve given our main sub-modules names (wte, wpe, h, ln_f) that are identical to the first-level scopes of the variables we saw in the TensorFlow checkpoint.

We can also write the code for our forward pass by converting the main model’s TensorFlow operations into PyTorch operations.
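Here is a minimal sketch of such a main class and its forward pass for GPT-2. It assumes the TF model adds token and position embeddings, runs the blocks in order, applies ln_f and projects back onto the token embedding matrix (which is what GPT-2’s TF model function does, to the best of our knowledge). Block is a deliberately empty stub so the skeleton runs end to end; the class and argument names are our own choices, not the original post’s code.

```python
import torch
import torch.nn as nn


class Block(nn.Module):
    """Stub for one hXX scope; its sub-modules (ln_1, attn, ln_2, mlp) come next."""

    def forward(self, x):
        return x  # identity until the real block is written


class GPT2Model(nn.Module):
    """Hypothetical skeleton: attribute names mirror the first-level TF scopes."""

    def __init__(self, n_vocab, n_ctx, n_embd, n_layer):
        super().__init__()
        self.wte = nn.Embedding(n_vocab, n_embd)   # model/wte: token embeddings
        self.wpe = nn.Embedding(n_ctx, n_embd)     # model/wpe: position embeddings
        self.h = nn.ModuleList([Block() for _ in range(n_layer)])  # model/h0 ... model/h{n_layer-1}
        self.ln_f = nn.LayerNorm(n_embd, eps=1e-5)  # model/ln_f; check the epsilon in your TF code

    def forward(self, input_ids):
        # TF side, roughly: h = tf.gather(wte, X) + tf.gather(wpe, positions)
        positions = torch.arange(input_ids.size(-1), device=input_ids.device)
        h = self.wte(input_ids) + self.wpe(positions)
        for block in self.h:
            h = block(h)
        h = self.ln_f(h)
        # Output projection tied to the token embeddings,
        # like tf.matmul(h_flat, wte, transpose_b=True) on the TF side.
        return h @ self.wte.weight.t()


model = GPT2Model(n_vocab=100, n_ctx=16, n_embd=8, n_layer=2)
logits = model(torch.randint(0, 100, (2, 5)))
print(logits.shape)  # torch.Size([2, 5, 100])
```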

Now we dive deeper into the hierarchy, continuing to build our PyTorch model by adapting the rest of the TensorFlow code. The next example is the TensorFlow code for a “Block” module, which we translate into an equivalent PyTorch nn.Module class.

Here again, the names of the class attributes containing the sub-modules (ln_1, ln_2, attn, mlp) are identical to the associated TensorFlow scope names that we saw in the checkpoint list above. This ensures that the hierarchical attribute structure of the PyTorch model is identical to the TensorFlow scope structure.
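As an illustration, here is one possible PyTorch version of a GPT-2 block. It is a sketch under a few assumptions: the TF block is pre-norm (x + attn(ln_1(x)), then x + mlp(ln_2(x))), attention is causal with q, k and v fused in c_attn, and the MLP uses the tanh approximation of GELU. Caching of past keys/values and dropout are left out. Because it uses nn.Linear, the TF kernels will need transposing at load time (the second “Beware of the details” section covers that).

```python
import math

import torch
import torch.nn as nn


def gelu(x):
    # tanh approximation of GELU; check which variant your TF code uses
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * x.pow(3))))


class Attention(nn.Module):  # TF scope hXX/attn
    def __init__(self, n_embd, n_head):
        super().__init__()
        assert n_embd % n_head == 0
        self.n_head = n_head
        self.c_attn = nn.Linear(n_embd, 3 * n_embd)  # hXX/attn/c_attn: fused q, k, v
        self.c_proj = nn.Linear(n_embd, n_embd)      # hXX/attn/c_proj

    def forward(self, x):
        b, t, c = x.shape
        d = c // self.n_head
        q, k, v = self.c_attn(x).split(c, dim=-1)
        # (b, t, c) -> (b, n_head, t, d)
        q, k, v = (z.view(b, t, self.n_head, d).transpose(1, 2) for z in (q, k, v))
        w = q @ k.transpose(-2, -1) / math.sqrt(d)
        causal = torch.tril(torch.ones(t, t, dtype=torch.bool, device=x.device))
        w = w.masked_fill(~causal, float("-inf"))  # each position only sees the past
        a = torch.softmax(w, dim=-1) @ v
        return self.c_proj(a.transpose(1, 2).reshape(b, t, c))


class MLP(nn.Module):  # TF scope hXX/mlp
    def __init__(self, n_embd):
        super().__init__()
        self.c_fc = nn.Linear(n_embd, 4 * n_embd)
        self.c_proj = nn.Linear(4 * n_embd, n_embd)

    def forward(self, x):
        return self.c_proj(gelu(self.c_fc(x)))


class Block(nn.Module):  # TF scope hXX
    def __init__(self, n_embd, n_head):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n_embd, eps=1e-5)
        self.attn = Attention(n_embd, n_head)
        self.ln_2 = nn.LayerNorm(n_embd, eps=1e-5)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))  # pre-norm residual connections
        x = x + self.mlp(self.ln_2(x))
        return x


block = Block(n_embd=8, n_head=2)
print(sorted(name for name, _ in block.named_children()))  # ['attn', 'ln_1', 'ln_2', 'mlp']
```

The child names printed at the end are exactly the scope names under each hXX, which is what makes the later weight loading a matter of walking names.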

Beware of the details — section I 🕵️

The computation flow

When you convert TensorFlow code to PyTorch code, you have to be careful to reproduce the exact computation flow of the TensorFlow model in PyTorch. For instance, you should reimplement all the operations, even those not associated with a Variable (i.e. not visible in the checkpoint), add the dropout modules at the same places as in the original model, and carefully check how to convert each TensorFlow method into an equivalent PyTorch operation.

It’s a good opportunity to dive into the internals of both frameworks and see how each operation is implemented under the hood. One example: TensorFlow and PyTorch layer normalizations are slightly different from each other (go check them out!), so I usually reimplement layer normalization from scratch in PyTorch.
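A from-scratch layer normalization makes every choice explicit: the variance estimator, where epsilon goes and its value. The sketch below writes it out in the form we understand GPT-2’s TF norm helper to use (mean, mean of squared deviations, epsilon inside the square root); adapt each line to whichever TF implementation you are converting. It then checks the result against PyTorch’s built-in F.layer_norm and shows one easy slip.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNorm(nn.Module):
    """Layer normalization written out step by step, TF style."""

    def __init__(self, n_state, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(n_state))  # to be loaded from the TF 'g' variable
        self.bias = nn.Parameter(torch.zeros(n_state))   # to be loaded from the TF 'b' variable
        self.eps = eps

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)  # biased variance: divide by n, not n - 1
        x = (x - u) / torch.sqrt(s + self.eps)     # epsilon inside the square root
        return self.weight * x + self.bias


torch.manual_seed(0)
x = torch.randn(2, 3, 8)
ln = LayerNorm(8)
ref = F.layer_norm(x, (8,), ln.weight, ln.bias, eps=1e-5)
print(torch.allclose(ln(x), ref, atol=1e-5))  # True

# A classic slip: torch.var is unbiased (divides by n - 1) by default.
u = x.mean(-1, keepdim=True)
slip = (x - u) / torch.sqrt(x.var(-1, keepdim=True) + 1e-5)
print(torch.allclose(slip, ref, atol=1e-5))  # False
```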

The initialization and defaults

It’s also important to check the default parameters of each module, such as epsilons, and make sure you use the same values in PyTorch as in TensorFlow. Be especially careful about default values that are not visible in the code.
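Normalization layers are a good concrete case of invisible defaults: PyTorch and tf.keras ship with different epsilons, and their batch-norm momentum parameters follow opposite conventions. The snippet below is a sketch assuming you are porting tf.keras layers left at their defaults (LayerNormalization: epsilon=1e-3; BatchNormalization: epsilon=1e-3, momentum=0.99). It prints the PyTorch defaults and spells out matching values.

```python
import torch
import torch.nn as nn

# PyTorch defaults: nothing in these constructor calls shows eps or momentum.
ln = nn.LayerNorm(768)
bn = nn.BatchNorm1d(768)
print(ln.eps, bn.eps, bn.momentum)  # 1e-05 1e-05 0.1

# Reproducing the tf.keras defaults means spelling them out explicitly.
ln_tf = nn.LayerNorm(768, eps=1e-3)
# Keras:   moving  = momentum * moving + (1 - momentum) * batch_stat
# PyTorch: running = (1 - momentum) * running + momentum * batch_stat
# so a Keras momentum of 0.99 corresponds to a PyTorch momentum of 0.01.
bn_tf = nn.BatchNorm1d(768, eps=1e-3, momentum=1 - 0.99)
```

The momentum only matters if you keep training, but the epsilon changes every forward pass, so it will show up in the hidden-state comparison later.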

Loading the weights 🏋️

Once the code conversion step is finished and your newly defined PyTorch model can run a forward pass on dummy inputs without errors, it’s time to load the TensorFlow weights into it 🐣

Having the same organization in both models makes the loading very easy: we just jointly iterate over the path-like names of the TensorFlow variables and the attributes of our PyTorch model.

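A loading function for GPT-2 therefore only has to walk each TensorFlow variable name down the attribute hierarchy of the PyTorch model and copy the corresponding array into the parameter it reaches.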
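Here is a minimal sketch of such a loader, not the one from PyTorch-Transformers. It assumes the arrays were already read into a dict, e.g. {n: tf.train.load_variable(path, n) for n, _ in tf.train.list_variables(path)}, and it assumes three naming conventions: numbered scopes like h0 index into an nn.ModuleList, leaves w and g map to weight, and b maps to bias. It also assumes w kernels are laid out (..., in, out) and must be transposed for nn.Linear. load_tf_weights and the tiny model are hypothetical.

```python
import re

import numpy as np
import torch
import torch.nn as nn


def load_tf_weights(model, tf_vars, prefix="model/"):
    """Copy TF arrays into `model` by walking each path-like name down its attributes."""
    for name, array in tf_vars.items():
        if not name.startswith(prefix):
            continue  # skip variables outside the model scope
        scopes = name[len(prefix):].split("/")
        pointer = model
        for scope in scopes:
            indexed = re.fullmatch(r"([A-Za-z]+)(\d+)", scope)
            if indexed:                    # 'h0' -> model.h[0]
                pointer = getattr(pointer, indexed.group(1))[int(indexed.group(2))]
            elif scope in ("w", "g"):      # kernels and layer-norm gains
                pointer = pointer.weight
            elif scope == "b":             # biases
                pointer = pointer.bias
            else:                          # plain sub-scope such as 'mlp' or 'ln_1'
                pointer = getattr(pointer, scope)
        if isinstance(pointer, nn.Module):  # e.g. 'model/wte' names an nn.Embedding
            pointer = pointer.weight
        if scopes[-1] == "w":
            # Assumed TF layout (..., in, out): drop leading size-1 dims, then
            # transpose to nn.Linear's (out, in).
            array = array.reshape(array.shape[-2:]).T
        if tuple(pointer.shape) != array.shape:
            raise ValueError(f"{name}: PyTorch {tuple(pointer.shape)} != TF {array.shape}")
        with torch.no_grad():
            pointer.copy_(torch.from_numpy(np.ascontiguousarray(array)))
    return model


# Tiny stand-in model using the same attribute names as the TF scopes.
class MLP(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.c_fc = nn.Linear(n, 4 * n)


class Block(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.ln_1 = nn.LayerNorm(n)
        self.mlp = MLP(n)


class TinyModel(nn.Module):
    def __init__(self, n_vocab=10, n=4, n_layer=2):
        super().__init__()
        self.wte = nn.Embedding(n_vocab, n)
        self.h = nn.ModuleList([Block(n) for _ in range(n_layer)])


rng = np.random.default_rng(0)
fake_vars = {  # stand-ins for tf.train.load_variable results
    "model/wte": rng.standard_normal((10, 4)).astype(np.float32),
    "model/h1/ln_1/g": rng.standard_normal(4).astype(np.float32),
    "model/h1/mlp/c_fc/w": rng.standard_normal((1, 4, 16)).astype(np.float32),
    "model/h1/mlp/c_fc/b": rng.standard_normal(16).astype(np.float32),
}
model = load_tf_weights(TinyModel(), fake_vars)
```

The explicit shape check catches most naming or layout mistakes immediately, except transpositions of square matrices, as the next section explains.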

Let’s talk about a few things to keep in mind at this stage 👇

Beware of the details — section II🕵️

Transposing tensors from TensorFlow to PyTorch

Some TensorFlow operations operate on weights that are transposed with regard to their PyTorch counterparts (or vice versa 😉). In this case, your weight-loading method should transpose the weights when loading them.

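The main case where this happens in practice is Keras-style dense layers such as tf.layers.dense, whose kernel has shape (in_features, out_features) and is therefore the transpose of PyTorch’s nn.Linear weight, of shape (out_features, in_features).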

This transposition issue can be especially tricky to detect for square matrices, since their shape is the same either way, which brings us to our last section 👇
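The square-matrix trap is easy to reproduce with NumPy alone. In this sketch, dense_tf and linear_pt are hypothetical helpers that mimic the two conventions; a forgotten transpose passes any shape check but changes the output.

```python
import numpy as np


def dense_tf(x, kernel):
    """tf.layers.dense-style product: kernel has shape (in, out)."""
    return x @ kernel


def linear_pt(x, weight):
    """nn.Linear-style product: weight has shape (out, in)."""
    return x @ weight.T


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4))
kernel = rng.standard_normal((4, 4))  # square: in == out

correct = kernel.T         # what the loader should copy
forgot_transpose = kernel  # a silent bug

# The shape check passes in both cases...
assert correct.shape == forgot_transpose.shape == (4, 4)
# ...but only the transposed weight reproduces the TF output.
print(np.allclose(linear_pt(x, correct), dense_tf(x, kernel)))           # True
print(np.allclose(linear_pt(x, forgot_transpose), dense_tf(x, kernel)))  # False
```

Only comparing actual outputs catches this kind of bug, which is why the comparison step below matters so much.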

The final step: comparing the models 👭

Comparing hidden-states 🎼

Now that your model runs and all the weights are initialized from their TensorFlow counterparts, it is time for the most important step: a careful comparison of both models!

I usually start from a script, provided by the authors of the original implementation, that runs the TensorFlow model, and then:

  • modify the TensorFlow model to output the hidden-states at regular locations along the depth of the model,
  • modify our PyTorch model to output the hidden-states at the same regular locations along the depth of the model,
  • load the PyTorch model in parallel with the TensorFlow model and run them on the same inputs,
  • compare their behaviors during a forward pass to detect where an error may have been made.

Make sure to deactivate the dropout modules and all other nondeterministic modules so that both models compute exactly the same function.
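On the PyTorch side this is usually a single call to model.eval(), which recursively switches modules such as nn.Dropout to inference behavior. On the TF side, pass whatever flag the original code uses for inference (for instance an is_training=False argument, if it has one) or set the dropout rates to zero. The model below is an arbitrary stand-in used only to show the effect.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in model; any module tree containing nn.Dropout behaves the same way.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.1), nn.Linear(8, 8))
x = torch.randn(4, 8)

model.eval()           # sets training=False on every sub-module: dropout becomes the identity
with torch.no_grad():  # no autograd bookkeeping needed for a comparison run
    out_a = model(x)
    out_b = model(x)
print(torch.equal(out_a, out_b))  # True
```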

If your script is a fine-tuning script and your model contains newly initialized weights, initialize the PyTorch model entirely from the newly initialized TensorFlow model so that the comparison is meaningful. We did this, for example, during the reimplementation of XLNet in pytorch-transformers, where the newly initialized TensorFlow model is saved and then loaded in PyTorch.

I usually compare the maximum absolute difference between the hidden states of both models after each layer, on a few real-life inputs.
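A small helper makes this comparison systematic. The sketch below assumes both models’ hidden states have already been collected as NumPy arrays in the same order (for PyTorch tensors, via .detach().cpu().numpy()). The tolerance of 1e-4 is an assumption; pick it relative to the magnitude of your activations. compare_hidden_states is a hypothetical helper.

```python
import numpy as np


def compare_hidden_states(tf_states, pt_states, tol=1e-4):
    """Per-layer max absolute difference, plus the first layer that exceeds tol (or None)."""
    if len(tf_states) != len(pt_states):
        raise ValueError(f"{len(tf_states)} TF states vs {len(pt_states)} PyTorch states")
    diffs = []
    for i, (a, b) in enumerate(zip(tf_states, pt_states)):
        if a.shape != b.shape:
            raise ValueError(f"layer {i}: TF shape {a.shape} vs PyTorch shape {b.shape}")
        diffs.append(float(np.max(np.abs(a - b))))
    first_bad = next((i for i, d in enumerate(diffs) if d > tol), None)
    return diffs, first_bad


# Synthetic check: identical states except for an error injected at layer 2.
rng = np.random.default_rng(0)
tf_states = [rng.standard_normal((1, 5, 8)).astype(np.float32) for _ in range(4)]
pt_states = [s.copy() for s in tf_states]
pt_states[2] += 1e-2
diffs, first_bad = compare_hidden_states(tf_states, pt_states)
for i, d in enumerate(diffs):
    print(f"layer {i}: max |diff| = {d:.2e}")
print("first diverging layer:", first_bad)
```

In a real model an error propagates to every later layer, so the first diverging layer is the one worth inspecting.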

Comparing on a down-stream task 🚣

If your model is a pretrained model that can be fine-tuned on a downstream task, you can further confirm the accuracy of the conversion by reproducing some results on such a task.

This can take quite a while, as you will need to reproduce the pre-processing, optimization and post-processing of the original authors’ work.

In our experience, a discrepancy at this stage almost never comes from a difference inside the models. It comes from the way the inputs are prepared, from the optimization parameters (one of the most often overlooked being the batch size), or from the post-processing and evaluation metrics.

That’s all, folks 👭

We’ve seen the main steps you can take to quickly and accurately reimplement a pretrained TensorFlow model in PyTorch.

This method has a few limits:

  • the model may end up having a deeper hierarchy than necessary. In this case, you can rewrite the model to reduce the number of classes and use a mapping between the TensorFlow variables and the PyTorch attributes 🗺
  • the model is sometimes implemented with operations that are fast in TensorFlow or on TPU (e.g. multiplication with one-hot matrices) but may be suboptimal in PyTorch. Here again, some rewriting and conversion afterward can help speed up the resulting model in some cases 🏎
  • you need access to the TensorFlow code for the conversion. It’s possible to convert a TensorFlow model without access to its code, e.g. a model only available on TensorFlow Hub, but it’s a far more difficult process. In PyTorch-pretrained-BigGAN we did that by inspecting the raw computation graph and guessing the high-level operations involved 🙃
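The one-hot case from the second point is easy to check. Multiplying by a one-hot matrix computes the same result as an embedding lookup, but it first materializes a (batch, seq, vocab) tensor. Here is a sketch with made-up sizes that verifies the two formulations agree in PyTorch, so the lookup can replace the matmul once the converted model matches the original. Which one is faster depends on hardware and sizes; no timings are claimed here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
table = torch.randn(1000, 64)          # e.g. an embedding matrix (vocab, dim)
ids = torch.randint(0, 1000, (8, 16))  # token indices

# TPU-friendly formulation: one-hot matrix times table, via a (8, 16, 1000) intermediate.
one_hot_out = F.one_hot(ids, num_classes=1000).to(table.dtype) @ table
# Idiomatic PyTorch: a direct lookup, equivalent to table[ids].
gather_out = F.embedding(ids, table)

print(torch.allclose(one_hot_out, gather_out))  # True
```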

👾 For detailed code examples of this process, you can have a look at the various models implemented in PyTorch-Transformers.

… and if you feel like adding one of your own, we will probably be more than happy to welcome a Pull Request on the repository! Just ping us beforehand to make sure we are not already working on it 😉


Stories @ Hugging Face
