How to Train a Neural Network Classifier on ImageNet using TensorFlow 2

Pedro Sandoval-Segura
Published in Analytics Vidhya
2 min read · Dec 23, 2020

A sampling of images from the ImageNet dataset, where each image belongs to one of 1000 classes.

Image classification is a classic problem in computer vision. Today, state-of-the-art models for this problem use neural networks, which means that implementing and evaluating these models requires the use of a deep learning library like PyTorch or TensorFlow.

You can easily find PyTorch tutorials for downloading a pretrained model, setting up the ImageNet dataset, and evaluating the model. But I could not find a comprehensive tutorial for doing the same in TensorFlow. In this article, I’ll show you how.

Requirements

  • Python 3
  • TensorFlow 2.3.1 (install with pip3 install tensorflow==2.3.1)
  • TensorFlow Datasets (install with pip3 install tensorflow-datasets==4.1.0)
  • CUDA and cuDNN (since I’m using an NVIDIA GPU)
  • ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar, which you can download from the ImageNet website. Note that these archives are typically kept on shared, read-only storage (so multiple users can access them), since they require ~156 GB of storage space.
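
Before going further, it can help to confirm that the installed packages match the versions pinned above. Here is a minimal sketch using only the standard library (Python 3.8+). The helper name check_pins is my own, and it assumes the packages were installed under the exact pip names listed above.

```python
from importlib import metadata

# Versions pinned in the requirements list above.
PINNED = {"tensorflow": "2.3.1", "tensorflow-datasets": "4.1.0"}


def check_pins(pins):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    report = {}
    for name, wanted in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[name] = (installed, installed == wanted)
    return report


for name, (installed, ok) in check_pins(PINNED).items():
    print(f"{name}: installed={installed}, pinned={PINNED[name]}, match={ok}")
```

A mismatch is not necessarily fatal, but the rest of this article was only tried against the pinned versions.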

Overview

We’ll work through three steps: preparing the ImageNet dataset, compiling a pretrained model, and finally, evaluating the accuracy of the model.

First, let’s import some packages:

Now, we’ll download the ImageNet labels and specify where our ImageNet archive files are located. In particular, data_dir should be the path where ILSVRC2012_img_train.tar and ILSVRC2012_img_val.tar are located, and write_dir should be the directory where we’d like to write extracted image content. Make sure your write_dir directory contains extracted, downloaded, and data directories.
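
A minimal sketch of this setup might look like the following. The function names (build_paths, load_label_names, load_imagenet_train) are hypothetical, and the labels URL is an assumption (it is the label file used in TensorFlow's own tutorials). The TensorFlow imports are deferred into the functions so that the path logic can be checked without loading TensorFlow.

```python
import os

# Assumption: label file hosted by TensorFlow; swap in your own if needed.
LABELS_URL = (
    "https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt"
)


def build_paths(write_dir):
    """Return the three sub-directories of write_dir that TFDS will use."""
    return {
        "extract_dir": os.path.join(write_dir, "extracted"),
        "download_dir": os.path.join(write_dir, "downloaded"),
        "data_dir": os.path.join(write_dir, "data"),
    }


def load_label_names():
    """Download the label file (cached by Keras) and return its lines."""
    import tensorflow as tf  # deferred: heavy import

    path = tf.keras.utils.get_file("ImageNetLabels.txt", LABELS_URL)
    with open(path) as f:
        return f.read().splitlines()


def load_imagenet_train(data_dir, write_dir):
    """Build a tf.data.Dataset of (image, label) pairs from the local archives."""
    import tensorflow_datasets as tfds  # deferred: heavy import

    paths = build_paths(write_dir)
    for p in paths.values():
        os.makedirs(p, exist_ok=True)  # create the directories if missing

    download_config = tfds.download.DownloadConfig(
        extract_dir=paths["extract_dir"],
        manual_dir=data_dir,  # where the ILSVRC2012 .tar files live
    )
    return tfds.load(
        "imagenet2012",
        data_dir=paths["data_dir"],
        split="train",
        shuffle_files=False,
        download=True,
        as_supervised=True,  # yield (image, label) tuples
        download_and_prepare_kwargs={
            "download_dir": paths["download_dir"],
            "download_config": download_config,
        },
    )


# Usage (placeholder paths, replace with your own):
# ds = load_imagenet_train("/path/to/archives", "/path/to/write_dir")
```

The first call extracts the archives and writes records under write_dir, which takes a while. Later calls should reuse the prepared data.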

The key here is that we called tfds.load with keyword arguments to the download_and_prepare call, specifying where our archive files were located and where extracted records should be placed.

Now, because MobileNet V2 and many other pretrained classification models expect 224 x 224 images as input, we’ll need to preprocess our data. Here we’ll use mobilenet_v2.preprocess_input(i), but if you’re using a different model, swap in that model’s own preprocessing function. For example, if I were using VGG-16, I’d instead call vgg16.preprocess_input(i).
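
Here is one way the preprocessing step could be sketched. The names make_preprocess_fn and scale_pixel are hypothetical. Cropping or padding to 224 x 224 with tf.image.resize_with_crop_or_pad is my assumption about how to reach the input size (tf.image.resize would rescale instead). For MobileNet V2, preprocess_input maps pixel values from [0, 255] to [-1, 1], which scale_pixel mirrors in plain Python for illustration.

```python
IMAGE_SIZE = 224


def scale_pixel(value):
    """Plain-Python illustration of MobileNet V2 scaling: [0, 255] -> [-1, 1]."""
    return value / 127.5 - 1.0


def make_preprocess_fn(size=IMAGE_SIZE):
    """Return a function suitable for ds.map() on (image, label) pairs."""
    import tensorflow as tf  # deferred: heavy import
    from tensorflow.keras.applications import mobilenet_v2

    def preprocess(image, label):
        image = tf.cast(image, tf.float32)
        # Center-crop or zero-pad to size x size (no rescaling).
        image = tf.image.resize_with_crop_or_pad(image, size, size)
        # Model-specific scaling; replace for other architectures.
        image = mobilenet_v2.preprocess_input(image)
        return image, label

    return preprocess


# Usage, given a dataset ds of (image, label) pairs:
# ds = ds.map(make_preprocess_fn()).batch(32)
```

Batching after the map matters here: raw ImageNet images come in different sizes, so they can only be stacked into a batch once they share the same shape.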

Next, we compile a model of our choice. In my case, that’s MobileNet V2.
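
A minimal sketch of this step, assuming the Keras applications API. The function name build_pretrained_model is hypothetical. Since the Keras MobileNetV2 with its top layer outputs softmax probabilities by default, the sketch uses a sparse categorical cross-entropy loss that expects probabilities rather than logits.

```python
def build_pretrained_model(weights="imagenet"):
    """Load MobileNet V2 with its ImageNet classification head and compile it."""
    import tensorflow as tf  # deferred: heavy import

    model = tf.keras.applications.MobileNetV2(include_top=True, weights=weights)
    model.trainable = False  # we only evaluate here, so freeze all layers

    model.compile(
        optimizer="adam",  # unused during evaluation, but compile() wants one
        loss="sparse_categorical_crossentropy",  # integer labels, softmax outputs
        metrics=["accuracy"],
    )
    return model


# Usage:
# model = build_pretrained_model()
```

The pretrained weights are downloaded and cached on the first call.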

Finally, because of the way we’ve set up our dataset, we can evaluate this model on the training data and print the accuracy using a few lines of code!
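
The evaluation could be sketched as below. The helper evaluate_model and its max_batches parameter are my own additions, not part of Keras. Limiting the number of batches is handy for a quick sanity check, since the full training split has over a million images.

```python
def evaluate_model(model, ds, max_batches=None):
    """Run model.evaluate on ds and return a {metric_name: value} dict."""
    if max_batches is not None:
        ds = ds.take(max_batches)  # evaluate on a prefix of the dataset
    values = model.evaluate(ds)
    # metrics_names is read after evaluate(), once Keras has populated it.
    return dict(zip(model.metrics_names, values))


# Usage, given a compiled model and a batched, preprocessed dataset:
# print(evaluate_model(model, ds, max_batches=100))
```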

With a functioning pretrained classifier, you can now finetune the model to fit the needs of your classification problem. The full implementation is below, but it can also be found on my GitHub.

Full implementation of downloading a pretrained model, setting up the ImageNet dataset, and evaluating the model in TensorFlow 2
