Efficiently using TPU for image classification

Leveraging TensorFlow and TPUs to build a flower classification system.

Dimitre Oliveira
The Startup
8 min read · Oct 13, 2020

A while ago (early 2020), Kaggle added TPUs as a hardware option in its kernel environment. Shortly after, they launched the competition “Flower Classification with TPUs” and challenged the community to use a large dataset of images and get the most out of this powerful hardware.

The objective of this article is to provide basic knowledge about the integration between TensorFlow and TPUs, so that you can build a good image classifier baseline that takes advantage of the TPU's processing power.

So, what are TPUs?

TPU is short for “Tensor Processing Unit”. TPUs are powerful hardware accelerators specialized in deep learning tasks. They were developed (and first used) by Google to process large image databases, such as extracting all the text from Street View.

Around that time the TensorFlow team had also published a new release of its framework. The TF 2.1 release was focused on TPUs, and it supports them both through the Keras high-level API and, at a lower level, in models using a custom training loop.

I am not going into too many details about what a TPU is and how it works. If you want to learn more, check out this awesome documentation that the Kaggle team has created. Instead, I am going to focus on how to leverage TPUs with TensorFlow.

About the dataset

The competition dataset has 16465 images for the training set and 7382 images for the test set. The data is distributed over 104 different classes, which means that for each image the task is to predict which of the 104 classes it belongs to. The classes are very unbalanced, which is another thing to keep in mind.

Bar plot with the data label distribution.

Now let's take a look at some image samples.

Dataset image samples.

As we can see, the images have very good quality. In fact, the data was provided in multiple resolutions (192x192, 224x224, 331x331, and 512x512). To be even more efficient, the data was provided in the TFRecords format. TFRecords is a very efficient format that TensorFlow provides to store and load data, and it is especially useful when you are using a TPU. If you want to learn more, here is a very good video that explains TFRecords.

Source: https://www.kaggle.com/docs/tpu

Experiment setup

From the image above it is clear that a TPU is no ordinary hardware; it is far more powerful than most of the GPUs you may find available. Kaggle provides a TPU v3-8, which has 8 cores, 420 teraflops, and 128GB of RAM. This is why for this experiment I am going to use images with 512x512 resolution and an EfficientNetB6 model. The B6 version of EfficientNet has 43 million parameters; this alone can be prohibitive for some hardware, and together with images at 512x512 resolution it becomes a good use case for a TPU.

So where do we begin?

The regular process to train a model on a TPU would be roughly something like this:
1. Convert the dataset to TFRecords.
2. Initialize the TPU system.
3. Get the dataset address.
4. Build a Tensorflow data pipeline.
5. Create and train the model.

We are going to use the Kaggle environment, which already comes with all the required libraries, but you can also use Google Colab (at the time of writing, Colab only offers TPU v2).

1. Convert the dataset to TFRecords

We can skip this part since the competition already provides the TFRecords files, but if you want to learn how it is done, check this notebook guide.

2. Initialize the TPU system

When you are using TPUs, you don’t actually code on the same virtual machine where the TPU hardware is located. Instead, your VM communicates with the VM that hosts the TPU. For this reason, you need to initialize the remote TPU system, but with TensorFlow this is actually very simple.

Initializing the TPU system takes just 3 lines of code, and it also lets us see the TPU address.
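
A minimal sketch of that initialization, assuming TensorFlow 2.x. The three numbered calls are the usual ones for connecting to a remote TPU; the `except` branch is an addition for illustration so that the snippet also runs on a machine without a TPU, and the variable name `tpu` is arbitrary.

```python
import tensorflow as tf

try:
    # 1. Locate the TPU; with no arguments the address is auto-detected
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    # 2. Connect this VM to the remote VM that hosts the TPU
    tf.config.experimental_connect_to_cluster(tpu)
    # 3. Initialize the TPU system
    tf.tpu.experimental.initialize_tpu_system(tpu)
    print("Running on TPU:", tpu.master())  # the TPU address
except ValueError:
    # Assumed fallback (not part of the 3 lines): no TPU was found
    tpu = None
    print("No TPU found, continuing on the default device")
```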

3. Get the dataset address

For the reason described above, we need to get the dataset address so that the virtual machine that hosts the TPU knows where to fetch the data it will process.
Kaggle has a library, “kaggle_datasets”, that makes this process very easy: with just one line of code we can get the dataset address.

Here, “tpu-getting-started” is the name of the dataset, and “/tfrecords-jpeg-512x512” is the path to the 512x512 TFRecords files.
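
As a rough sketch, the lookup could be written as below. `kaggle_datasets` is only importable inside Kaggle kernels, so the local fallback path is a hypothetical placeholder added purely so the snippet does not fail elsewhere.

```python
try:
    # Only available inside the Kaggle kernel environment
    from kaggle_datasets import KaggleDatasets

    # Dataset address (a GCS bucket) plus the folder with the 512x512 files
    GCS_PATH = KaggleDatasets().get_gcs_path("tpu-getting-started") + "/tfrecords-jpeg-512x512"
except ImportError:
    # Hypothetical local copy of the dataset, assumed for illustration only
    GCS_PATH = "./tpu-getting-started/tfrecords-jpeg-512x512"

print(GCS_PATH)
```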

4. Build a Tensorflow data pipeline

We will use the TensorFlow data API to build the data pipeline that will feed the TPU.
Our pipeline will perform the following steps:

1. Convert the TFRecords files to a TF dataset.
2. Parse and decode the dataset.
3. Apply data transformations like data augmentation.
4. Apply dataset operations like shuffle, cache, or batch.

4.1. Convert the TFRecords files to a TF dataset

This step takes all the TFRecords files and converts them to a TFRecordDataset. The “filenames” passed to it come from the files stored at the dataset address.

To build that list, we load all files with the “.tfrec” extension inside the “train” folder.
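
A small sketch of both steps together, assuming TensorFlow 2.4 or newer (for `tf.data.AUTOTUNE`). The helper name `load_train_dataset` and the parallel-read setting are illustrative choices, not taken from the original notebook.

```python
import tensorflow as tf

def load_train_dataset(data_path):
    """List the training TFRecords files under data_path and wrap them in a dataset."""
    # Load all files with the ".tfrec" extension inside the "train" folder to a list
    filenames = tf.io.gfile.glob(data_path + "/train/*.tfrec")
    # Each element of the dataset is one serialized record (still raw bytes)
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
    return filenames, dataset
```

On Kaggle, `data_path` would be the dataset address obtained in step 3.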

4.2. Parse and decode the dataset

In this step we map a function over the TFRecordDataset that we created. The function first parses each record into samples with “string” and “int64” formats, and then we can apply other regular functions. Here we decode the JPEG files, normalize them so they have values in the [0, 1] range, and finally reshape them to the expected shape (512x512x3).
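
The parsing function could look roughly like the sketch below. The feature names `"image"` and `"class"` are assumptions about how the records were written, and the function names are illustrative.

```python
import tensorflow as tf

IMAGE_SIZE = [512, 512]

def decode_image(image_data):
    image = tf.image.decode_jpeg(image_data, channels=3)  # JPEG bytes -> uint8 pixels
    image = tf.cast(image, tf.float32) / 255.0             # normalize to the [0, 1] range
    return tf.reshape(image, [*IMAGE_SIZE, 3])             # set the expected static shape

def read_labeled_tfrecord(example):
    # Assumed record layout: a JPEG string and an integer class label
    feature_spec = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "class": tf.io.FixedLenFeature([], tf.int64),
    }
    example = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(example["image"])
    label = tf.cast(example["class"], tf.int32)
    return image, label
```

It would then be applied with something like `dataset.map(read_labeled_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)`.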

4.3. Apply data transformations like data augmentation

We can also use the map function to apply data augmentation. Here we apply some basic image augmentations: flips, rotations, and brightness adjustment.
There are many more options in the TensorFlow image module.
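
One possible version of such an augmentation function is sketched below. The exact transformations and the brightness range are assumptions chosen for illustration; rotations are limited to multiples of 90 degrees because that is what `tf.image.rot90` supports.

```python
import tensorflow as tf

def data_augment(image, label):
    # Random horizontal and vertical flips
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # Rotate by a random multiple of 90 degrees (0, 90, 180 or 270)
    k = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=k)
    # Small random brightness shift, then clip back to the [0, 1] range
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.clip_by_value(image, 0.0, 1.0)
    return image, label
```

Like the parsing step, it can be applied with `dataset.map(data_augment, num_parallel_calls=tf.data.AUTOTUNE)`.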

4.4. Apply dataset operations like shuffle, cache, or batch

Finally, TensorFlow provides some useful operations in the dataset API. Let’s go through what each one does.

repeat: Repeats this dataset, so it can be iterated over infinitely.
shuffle: Randomly shuffles the elements of this dataset.
batch: Combines consecutive elements of this dataset into batches.
cache: Caches the elements in this dataset in memory.
prefetch: Prepares later elements while the current one is being processed.
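
The operations above can be chained as in the sketch below. The ordering and the default shuffle buffer size are one common choice rather than necessarily the ones used in the competition notebooks, and the helper name is illustrative. `drop_remainder=True` is an added assumption that keeps every batch the same size, which TPUs generally prefer.

```python
import tensorflow as tf

def prepare_training_dataset(dataset, batch_size, shuffle_buffer=2048):
    dataset = dataset.cache()                      # keep elements in memory after the first pass
    dataset = dataset.repeat()                     # iterate over the data indefinitely
    dataset = dataset.shuffle(shuffle_buffer)      # randomize the element order
    dataset = dataset.batch(batch_size, drop_remainder=True)  # fixed-size batches
    dataset = dataset.prefetch(tf.data.AUTOTUNE)   # prepare the next batches during training
    return dataset
```

Caching only pays off when the data fits in memory, so for large decoded images it may be better to leave that step out.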

5. Create and train the model

Now that we have our data pipeline ready we can create the model that will use it.

As said at the beginning, we are using EfficientNetB6 with input images of 512x512, initialized with the ImageNet pre-trained weights. Fortunately, the TensorFlow applications module already provides the model and the weights. On top of that we add a simple global average pooling layer and a softmax classifier layer with 104 classes.

For the model to be able to train on the TPU we need to define a distribution strategy. With TensorFlow we can do that with just one line of code; then we create the model inside the strategy's scope, and it is ready to be trained.
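
A hedged sketch of this step, assuming TensorFlow 2.3 or newer (where `tf.keras.applications.EfficientNetB6` is available). The helper names, the `backbone` parameter, the optimizer, and the loss are illustrative assumptions. `tf.distribute.get_strategy()` is a stand-in so the snippet runs without a TPU; on a TPU the strategy would be `tf.distribute.TPUStrategy(tpu)`, with `tpu` being the resolver from step 2.

```python
import tensorflow as tf

def build_model(backbone=tf.keras.applications.EfficientNetB6,
                input_shape=(512, 512, 3), n_classes=104, weights="imagenet"):
    # Pre-trained convolutional base without its original classifier head
    base = backbone(include_top=False, weights=weights, input_shape=input_shape)
    model = tf.keras.Sequential([
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(n_classes, activation="softmax"),
    ])
    # Assumed training setup: integer labels, so the sparse loss is used
    model.compile(optimizer="adam",
                  loss="sparse_categorical_crossentropy",
                  metrics=["sparse_categorical_accuracy"])
    return model

def create_model_in_scope(strategy, **model_kwargs):
    # Variables created inside the scope are placed according to the strategy
    with strategy.scope():
        return build_model(**model_kwargs)

# Stand-in strategy; on a TPU use tf.distribute.TPUStrategy(tpu) instead
strategy = tf.distribute.get_strategy()
# model = create_model_in_scope(strategy)  # downloads the ImageNet weights
```

Note that the built-in Keras EfficientNet models include their own input rescaling and are documented as expecting pixel values in the [0, 255] range, so the [0, 1] normalization from the decoding step may need adjusting if this exact backbone is used.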

When we are using pre-trained weights, it is usually a good idea to use either a warmup pre-training step or a learning rate schedule with a warmup phase.

Learning rate schedule with a warmup phase.

In the image above, the X-axis is the epoch count and the Y-axis is the learning rate value. During the warmup phase the learning rate starts at a very low value and increases linearly; after that it decays exponentially over time.
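
A schedule with this shape can be sketched in plain Python as follows. All the numeric defaults (start, peak, and minimum learning rates, number of warmup epochs, decay factor) are hypothetical values chosen for illustration, not the ones behind the plot.

```python
def lr_schedule(epoch, lr_start=1e-5, lr_max=1e-3, lr_min=1e-5,
                warmup_epochs=5, decay=0.8):
    if epoch < warmup_epochs:
        # Warmup: linear increase from lr_start up to lr_max
        return lr_start + (lr_max - lr_start) * epoch / warmup_epochs
    # After warmup: exponential decay from lr_max towards lr_min
    return lr_min + (lr_max - lr_min) * decay ** (epoch - warmup_epochs)

for epoch in range(10):
    print(epoch, lr_schedule(epoch))
```

With Keras, a function like this can be passed to `tf.keras.callbacks.LearningRateScheduler` and given to `model.fit` through the `callbacks` argument.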

After this, the training happens as usual.

This basically covers the whole process up to the training step. From there you can do things as usual; in fact, the training step code itself is the same as you would write for other hardware if you are using a TF dataset.

Here is the complete code with some cleaning

I have written 3 notebooks for this competition, which earned me the TPU star 3rd place prize at the time.

Flower Classification with TPUs — EDA and Baseline (A simple baseline, very close to the code used here)
Flower with TPUs — Advanced augmentations (How to better control data augmentation with TF datasets pipelines)
Flower with TPUs — K-Fold optimized training loop (K-Fold training using an optimized training loop for TF)

References and additional links:

- Keras and modern convnets, on TPUs.
- Tensor Processing Units (TPUs).
- Google cloud TPU performance guide.
- Getting started with 100+ flowers on TPU (notebook).

One last acknowledgment I would like to make is to highlight the effort Google has been making towards A.I. democratization by making all these resources freely available through Kaggle and Google Colab. I cannot say how much it helped me to be able to run countless hours of experiments with their resources.

Dimitre Oliveira

Machine Learning Engineer at Intuition Machines | Solutions Architect at Virtus | Google Developer Expert on Machine | Kaggle Grandmaster