PyTorch / XLA is now Generally Available on Google Cloud TPUs

Authors: Ailing Zhang (FB), Joe Spisak (FB), Geeta Chauhan (FB), Vaibhav Singh (Google), Isaack Karanja (Google)

The PyTorch-TPU project was announced at the PyTorch Developer Conference 2019 and originated from a collaboration among engineers and researchers at Facebook, Google, and Salesforce Research. The overarching goal of the project was to make it as easy as possible for the PyTorch community to leverage the high-performance capabilities of Cloud TPUs while maintaining the dynamic PyTorch user experience. To enable this workflow, the team created PyTorch / XLA, a package that lets PyTorch connect to Cloud TPUs and use TPU cores as devices. As part of the project, Colab also enabled PyTorch / XLA support on Cloud TPUs. Fast forward to September 2020: the PyTorch / XLA library has reached general availability (GA) on Google Cloud and supports a broad set of entry points for developers. It also has a fast-growing community of researchers and enterprise users training a wide range of models accelerated with Cloud TPUs and Cloud TPU Pods, including researchers and engineers at MIT, Salesforce Research, Allen AI, and elsewhere.

PyTorch Developer Conference 2019 | PyTorch on Google Cloud TPUs — Google, Salesforce, Facebook

What’s new for the GA release?

With PyTorch / XLA GA, PyTorch 1.6 is officially supported on Cloud TPUs. Other notable new features include:

  • Support for Intra-Layer Model Parallelism: The all-reduce operation can now be performed with multiple operation types and groups. More communication primitives have been added to enable interesting applications such as distributing large embedding tables over multiple TPU cores;
  • Additional XLA ops: As PyTorch / XLA usage grew across an ever-widening range of new models, users asked for more PyTorch ops to be mapped to XLA, and we responded. Since the beta (1.5) release, we have incorporated XLA lowerings for replication_pad1d, replication_pad2d, max_unpool2d, max_unpool3d, and other ops;
  • Better Experience in Colab / Kaggle Notebooks: You no longer need to run a separate setup script on Colab / Kaggle before you start training; and
  • Support within Deep Learning VM Images: Google Cloud Platform provides a set of Deep Learning Virtual Machine (DLVM) images that include everything you need to get started with various deep learning frameworks, including PyTorch. PyTorch / XLA 1.6 is now pre-installed in DLVM and optimized for Cloud TPUs. The official PyTorch 1.6 is also pre-installed in the same Conda environment. Follow this user guide to get started.

What models are supported?

PyTorch / XLA has been used to train numerous deep learning models on Cloud TPUs, and reference implementations are available for a diverse set of models.

In most cases, training these models on Cloud TPUs requires very few code changes. You can find official tutorials on Google Cloud for ResNet-50, Fairseq Transformer, Fairseq RoBERTa, DLRM, and PyTorch on Cloud TPU Pods. Check out the PyTorch / XLA GitHub repository for examples of other model architectures trained on Cloud TPUs.

How does PyTorch / XLA work?

PyTorch / XLA works using a ‘lazy tensor’ abstraction. With lazy tensors, the evaluation of a tensor operation is deferred until its result is actually required (for example, for control flow or reporting). Up until that point, the operations are captured as an Intermediate Representation (IR) graph. Once results are required, the IR graph is compiled via XLA and sent to TPU cores for execution. This XLA compilation can also target CPUs and GPUs.
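The deferred-evaluation idea can be illustrated with a toy example in plain Python. This is only a sketch of the concept, not how PyTorch / XLA is implemented: the names `Node`, `LazyTensor`, `evaluate`, and `EXECUTED` are hypothetical, and the graph is interpreted directly here, whereas the real library compiles it with XLA.

```python
import operator
from dataclasses import dataclass
from typing import Optional, Tuple

# Toy illustration only: all names here are hypothetical, not torch_xla APIs.
_OPS = {"add": operator.add, "mul": operator.mul}
EXECUTED = []  # ops that have actually been computed, in order


@dataclass(frozen=True)
class Node:
    """One node of the recorded IR graph."""
    op: str
    inputs: Tuple["Node", ...] = ()
    value: Optional[float] = None  # only set for "const" nodes


def evaluate(node: Node) -> float:
    """Walk the graph and compute it (stand-in for compile-and-run)."""
    if node.op == "const":
        return node.value
    args = [evaluate(n) for n in node.inputs]
    EXECUTED.append(node.op)
    return _OPS[node.op](*args)


class LazyTensor:
    """Records operations instead of running them."""

    def __init__(self, node: Node):
        self.node = node

    @classmethod
    def constant(cls, value: float) -> "LazyTensor":
        return cls(Node("const", (), value))

    def __add__(self, other: "LazyTensor") -> "LazyTensor":
        return LazyTensor(Node("add", (self.node, other.node)))

    def __mul__(self, other: "LazyTensor") -> "LazyTensor":
        return LazyTensor(Node("mul", (self.node, other.node)))

    def item(self) -> float:
        # The result is required here, so the graph is evaluated now.
        return evaluate(self.node)


a = LazyTensor.constant(2.0)
b = LazyTensor.constant(3.0)
c = a * b + a                 # only builds the graph; nothing runs yet
ops_before = list(EXECUTED)   # still empty at this point
result = c.item()             # forces evaluation of the whole graph
```

Until `item()` is called, `c` holds nothing but the recorded graph, which mirrors the point in the text at which results are "required".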

Additional technical details about the approach are available on GitHub.

What code changes are needed to get started?

To start training, you need to create a Google Compute Engine VM (user VM) with the PyTorch / XLA image and a separate Cloud TPU instance.

Once the user VM and the Cloud TPU instance are created, you can activate the appropriate conda environment and set the XRT_TPU_CONFIG environment variable to point to the Cloud TPU instance.
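If you prefer to set the variable from a Python script or notebook rather than the shell, a minimal sketch might look like the following. It assumes the single-worker `tpu_worker;0;<ip>:8470` value format described in the PyTorch / XLA 1.6 documentation; the helper name `xrt_tpu_config` and the IP address are placeholders, so substitute the internal IP of your own Cloud TPU instance.

```python
import os


def xrt_tpu_config(tpu_ip: str, port: int = 8470) -> str:
    """Build a single-worker XRT_TPU_CONFIG value (format is an assumption
    based on the 1.6-era documentation; this helper is hypothetical)."""
    return f"tpu_worker;0;{tpu_ip}:{port}"


# Placeholder address: replace with the internal IP of your Cloud TPU.
TPU_IP_ADDRESS = "10.0.0.2"

# Set the variable before torch_xla is imported, so that it is visible
# when the library initializes.
os.environ["XRT_TPU_CONFIG"] = xrt_tpu_config(TPU_IP_ADDRESS)
```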

At this point, you are ready to start training your model on a Cloud TPU! Let’s look at some sample code for training a “toy model” and notice the elements unique to PyTorch / XLA:

The lines highlighted above are: import statements for PyTorch / XLA components, the method to access the XLA device abstraction, and the parallel dataloader to facilitate overlapped data transfer and Cloud TPU execution. Also note the optimizer_step method, which performs the all-reduce operation followed by the parameter update (optimizer.step) behind the scenes. (GPU and CPU device types are also supported by PyTorch / XLA with no change in the code. Only an XRT_TPU_CONFIG variable is set differently to target these other hardware platforms.)
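To make the "all-reduce followed by the parameter update" step concrete, here is a toy simulation in plain Python. It is a sketch under stated assumptions, not the library's implementation: replicas are modelled as lists of floats, gradients are averaged across replicas (an assumption of this sketch), and the update is plain SGD. The function names `all_reduce_mean` and `toy_optimizer_step` are hypothetical.

```python
from typing import List


def all_reduce_mean(grads_per_replica: List[List[float]]) -> List[float]:
    """Combine gradients element-wise across replicas by averaging.
    After an all-reduce, every replica holds this same result."""
    n = len(grads_per_replica)
    return [sum(g) / n for g in zip(*grads_per_replica)]


def toy_optimizer_step(
    params_per_replica: List[List[float]],
    grads_per_replica: List[List[float]],
    lr: float,
) -> List[List[float]]:
    """Step 1: all-reduce the gradients. Step 2: apply the same SGD
    update on every replica, so the model copies stay in sync."""
    reduced = all_reduce_mean(grads_per_replica)
    return [
        [p - lr * g for p, g in zip(params, reduced)]
        for params in params_per_replica
    ]


# Two replicas start with identical parameters but see different data,
# so their local gradients differ.
params = [[1.0, 2.0], [1.0, 2.0]]
grads = [[0.25, 0.5], [0.75, 0.0]]
new_params = toy_optimizer_step(params, grads, lr=0.5)
```

Because every replica applies the same reduced gradient, the parameter copies remain identical after the step even though each replica computed a different local gradient.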

In contrast, here is example code for training the same model on 4 GPU devices using PyTorch (without PyTorch / XLA):

As you compare the two code samples above, you may note the following: 1) The model code required no changes to execute on Cloud TPUs; 2) In the training loop, there are similarities between the PyTorch API and PyTorch / XLA to wrap and transfer the model object to the corresponding device abstraction (CUDA in case of GPUs and xla_device in case of Cloud TPUs). There is also an additional element for parallel data loading as described above.

Training on Cloud TPU Pods

PyTorch / XLA also provides utilities to scale the training you just executed on an individual Cloud TPU (v3-8, for example) to a full Cloud TPU Pod (v3-2048) or any intermediate-sized Cloud TPU Pod slice (e.g. v3-32, v3-128, v3-256, v3-512, and v3-1024). This scaling is done using the xla_dist wrapper:

In order to set up distributed training, you create the Cloud TPU Pod slice of the desired size and a corresponding instance group with an appropriate number of VMs (matching the number of TPU cores divided by eight). Also, make sure that the training dataset is accessible to the respective virtual machines (workers) in the instance group. To start the training, launch your training script with the xla_dist wrapper as shown above; xla_dist will orchestrate the Cloud TPU mesh configuration and execute the training script on each of the workers.
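The sizing rule above (number of VMs equals the number of TPU cores divided by eight) is simple arithmetic, sketched below. The helper names are hypothetical, and the code assumes the slice name ends in the core count, as in the examples in the text.

```python
CORES_PER_VM = 8  # from the text: one VM per eight TPU cores


def core_count(tpu_type: str) -> int:
    """Extract the core count from a slice name such as "v3-32"."""
    return int(tpu_type.rsplit("-", 1)[1])


def vms_needed(tpu_type: str) -> int:
    """Number of VMs the instance group needs for the given slice."""
    cores = core_count(tpu_type)
    if cores % CORES_PER_VM != 0:
        raise ValueError(f"{tpu_type}: core count is not a multiple of 8")
    return cores // CORES_PER_VM


for name in ("v3-8", "v3-32", "v3-2048"):
    print(name, "->", vms_needed(name), "VM(s)")
```

For example, a v3-32 slice needs 4 VMs in its instance group, and a full v3-2048 Pod needs 256.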

A more detailed guide on how to get started with Cloud TPU Pods is available here. Further details are available on GitHub.

Getting Started

Colab notebooks provided in the official PyTorch / XLA repository are an excellent place to start exploring PyTorch / XLA on Cloud TPUs. Once you are familiar enough with the API, you can start to work with your own models following the setup provided in official examples.


This project would not have been possible without contributions from Alex Suhan (Prior affiliations with both Google and FB), Bryan McCann (Salesforce Research), Carlos Escapa (FB), Jin Young Daniel Sohn (Google), Davide Libenzi (Formerly at Google), Jack Cao (Google), Mike Ruberry (FB), Shauheen Zahirazami (Google), Shauna Kelleher (FB), Soumith Chintala (FB), Taylan Bilal (Google), Vishal Mishra (Google), Woo Kim (FB), Zach Cain (Google), and Zak Stone (Google).

The announcement from the Google team can be found here.




