Training on a TPU in parallel using PyTorch XLA

Train your model many times faster by using all TPU cores at once!

Abhishek Swain
Published in TDS Archive · 3 min read · May 9, 2020


Image taken from the Google Cloud blog on TPUs

Taken from the Kaggle TPU documentation:

TPUs are now available on Kaggle, for free. TPUs are hardware accelerators specialized in deep learning tasks. They are supported in Tensorflow 2.1 both through the Keras high-level API and, at a lower level, in models using a custom training loop.

You can use up to 30 hours per week of TPUs and up to 3h at a time in a single session.

The Kaggle documentation only covers how to train a model on a TPU with TensorFlow, but I wanted to do it using PyTorch. PyTorch has XLA, which is what we are going to use to run our code on a TPU. The problem I faced was that there was no single source of information about how to do it; it was all scattered over the place!

I did quite a bit of research and found this amazing kernel by Abhishek Thakur. He explained how to train on a TPU in parallel on all of its 8 cores. He even has a YouTube video that explains training on a TPU. Check it out here: https://www.youtube.com/watch?v=2oWf4v6QEV8.

Okay, so let’s begin!

First, we need to install torch_xla. All you need to do is copy and paste these two lines into Colab or Kaggle and run them:
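At the time of writing, the two lines looked roughly like the sketch below. The --version string is just an example and the env-setup.py script location may have changed since, so check the current torch_xla install instructions:

```
# Run in a Colab/Kaggle notebook cell to install torch_xla.
# Pick the build that matches the PyTorch version in your environment.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 20200420 --apt-packages libomp5 libopenblas-dev
```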

Next come the required XLA imports:
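These are the standard torch_xla modules used throughout the rest of the code (xmp is only needed if you launch the training function on all eight cores yourself):

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm               # xla_device, optimizer_step, save, master_print
import torch_xla.distributed.parallel_loader as pl  # wraps a DataLoader for per-core loading
import torch_xla.distributed.xla_multiprocessing as xmp  # spawns one process per TPU core
```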

So, I used the TPU to train my model for a Kaggle competition. It's a simple one called Plant Pathology 2020; you can check it out. I am going to skip over the data preprocessing and modeling code, as that is a topic for another blog. Here, we are only concerned with running the model on a TPU. I will attach the link to the complete IPython notebook for you.

So, jumping straight to the training code, I will highlight the things needed for running the model in parallel. The first important thing is a distributed sampler for our DataLoader:

xm.xrt_world_size() retrieves the number of devices that are taking part in the replication (basically the number of cores).

xm.get_ordinal() retrieves the replication ordinal of the current process. The ordinals range from 0 to xrt_world_size() - 1.
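Putting those two together, the distributed sampler and DataLoader look roughly like this (train_dataset and the batch size of 16 are placeholders for whatever your notebook defines):

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_sampler = DistributedSampler(
    train_dataset,                     # your Dataset instance
    num_replicas=xm.xrt_world_size(),  # number of TPU cores (8)
    rank=xm.get_ordinal(),             # which core this process runs on
    shuffle=True,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    sampler=train_sampler,
    num_workers=4,
    drop_last=True,                    # keeps per-core batch counts equal
)
```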

The next thing is training the model in parallel: the traditional DataLoader has to be wrapped in a ParallelLoader object and then passed into the training function. For this we do pl.ParallelLoader(<your-dataloader>, [device]).

The device here is device = xm.xla_device(). We are simply specifying where to send our model to run. In this case, it is a TPU, or, as PyTorch likes to call it, an XLA device. (If you're a PyTorch user, you can think of it as similar to torch.device('cuda'), which is used to send tensors to a GPU.)
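Concretely, the device and loader setup looks something like this (model and train_loader come from your own notebook):

```python
device = xm.xla_device()   # the TPU core assigned to the current process
model = model.to(device)   # move the model onto the XLA device

# Wrap the ordinary DataLoader so batches are pre-loaded onto this core.
para_loader = pl.ParallelLoader(train_loader, [device])
```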

torch_xla has its own specific requirements. You can't simply make a device using xm.xla_device() and pass the model to it.

With that:

  1. The optimizer has to be stepped with xm.optimizer_step(optimizer).
  2. You have to save the model with xm.save(model.state_dict(), '<your-model-name>').
  3. You have to use xm.master_print(...) to print.
  4. For parallel training, we first define the distributed train and valid samplers, then wrap the DataLoaders with pl.ParallelLoader(<your-data-loader>, [device]) to create the ParallelLoader object I explained above.
  5. When passing it to the training and validation functions, we specify para_loader.per_device_loader(device). This is what you iterate over in the training function, i.e. we pass a ParallelLoader and not a DataLoader (for parallel training only); see the sketch after this list.
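Here is a minimal sketch tying points 1 to 5 together. The function name train_fn, the criterion, and the surrounding epoch loop are placeholders I am assuming for illustration; only the xm.* and pl.* calls are the TPU-specific parts:

```python
def train_fn(data_loader, model, optimizer, criterion, device):
    model.train()
    for batch_idx, (images, targets) in enumerate(data_loader):
        # Batches from per_device_loader(device) already live on the XLA device.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)  # steps the optimizer and syncs the TPU cores
        if batch_idx % 50 == 0:
            xm.master_print(f"batch={batch_idx} loss={loss.item():.4f}")  # printed once, not 8 times


# Inside the per-process run function (launched on all 8 cores,
# e.g. with xmp.spawn(_mp_fn, nprocs=8)):
device = xm.xla_device()
model = model.to(device)

para_loader = pl.ParallelLoader(train_loader, [device])
train_fn(para_loader.per_device_loader(device), model, optimizer, criterion, device)

xm.save(model.state_dict(), "model.bin")  # safe way to save from multiple processes
```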

With all of these specifications in place, you're now ready to train your model on a TPU, not just on Kaggle but on Colab too. I know it doesn't all make complete sense yet; this is just a preliminary explanation you can come back to for the specifics. Once you look at the complete code, you will understand everything. Best of luck! :)

As promised, here is my complete Kaggle kernel: https://www.kaggle.com/abhiswain/pytorch-tpu-efficientnet-b5-tutorial-reference. If you found it useful, you can upvote it! You can also go there, run it directly, and see the magic for yourself.

Let me tell you something: I made this to keep as a reference for myself, but decided to share it. Let this be a stop for you on your journey in deep learning :)

Torch XLA documentation: https://pytorch.org/xla/
