Image Classification with tf.keras (Introductory Tutorial)

Learn the basics of tf.keras for image classification

Binita Gyawali Regmi
Analytics Vidhya
5 min read · Jan 11, 2021

--

TensorFlow is one of the most popular open-source machine learning frameworks. Developed and supported by Google, it has played a vital role in the rise of machine learning, especially in industry, thanks to its rich ecosystem, which ranges from TensorFlow Lite, used for deploying models on mobile phones, to TensorFlow Serving, used for serving TensorFlow models in production. tf.keras is a high-level API built on top of TensorFlow, designed so that we can develop and train machine learning models in just a few lines of code.

In this tutorial, we use tf.keras to build and train an image classifier.

First, let's import the relevant libraries, sub-packages, modules, and classes.
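A minimal import cell for this tutorial (assuming TensorFlow 2.x, with NumPy and matplotlib installed) could look like:

```python
# Core libraries used throughout this tutorial.
import pathlib

import matplotlib.pyplot as plt  # plotting images and metrics
import numpy as np               # array manipulation

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
```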

Now, let's download the flowers dataset (around 3,700 images) from the dataset URL. The dataset has five classes, one for each type of flower, and the images of each type are kept in a separate folder named after that type.

Let's peek at a sample image from the roses folder.
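These two steps can be sketched with tf.keras.utils.get_file, which caches the archive after the first download (viewing the image assumes Pillow is installed; the exact extraction path can vary slightly across TensorFlow versions):

```python
import pathlib

import PIL.Image
import tensorflow as tf

# Download (cached after the first run) and extract the flowers archive.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file("flower_photos", origin=dataset_url, untar=True)
)

# Count the photos across all five class folders.
image_count = len(list(data_dir.glob("*/*.jpg")))
print(image_count)

# Open one rose image (in a notebook, this displays the image inline).
roses = list(data_dir.glob("roses/*"))
img = PIL.Image.open(str(roses[0]))
```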

No wonder, we get a beautiful rose here with water drops dripping through its petals. :)

Back to machine learning now. Let's define the parameters for training. Batch size is the number of images processed together in one step. Image height and width give the dimensions the images are resized to for training and prediction.
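For example (the 32/180/180 values follow the standard tf.keras flowers example; they are choices, not requirements):

```python
# Training parameters. Every image is resized to 180x180 when loaded.
batch_size = 32
img_height = 180
img_width = 180
```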

Here, we use 80% of the data for training and 20% for validation. tf.keras makes creating the training and validation datasets easy with its dataset API.
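A sketch of the split using image_dataset_from_directory (in TensorFlow versions before roughly 2.9 the same function lives at tf.keras.preprocessing.image_dataset_from_directory); the setup lines repeat the earlier steps so this snippet runs on its own:

```python
import pathlib

import tensorflow as tf

# Setup repeated from earlier steps (download is cached after the first run).
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = pathlib.Path(
    tf.keras.utils.get_file("flower_photos", origin=dataset_url, untar=True)
)
batch_size = 32
img_height = 180
img_width = 180

# The same fixed seed in both calls keeps the 80/20 split consistent.
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)

val_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
)
```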

The class_names attribute of the dataset (either training or validation) gives the list of class names, which we later use to map the index of the highest predicted value (the largest logit or softmax probability) back to a class name. Let's see all the classes in our flower classification example.

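A sketch of reading the class names (the setup repeats earlier steps so the snippet stands alone):

```python
import pathlib

import tensorflow as tf

# Setup repeated from earlier steps (download is cached after the first run).
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        "flower_photos",
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        untar=True,
    )
)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="training",
    seed=123, image_size=(180, 180), batch_size=32,
)

# Class names are inferred from the folder names, in alphabetical order.
class_names = train_ds.class_names
print(class_names)
# ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']
```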

Let's visualize the images in a 3x3 collage.
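One way to draw the collage with matplotlib (setup repeated from earlier steps so the snippet runs standalone):

```python
import pathlib

import matplotlib.pyplot as plt
import tensorflow as tf

# Setup repeated from earlier steps (download is cached after the first run).
data_dir = pathlib.Path(
    tf.keras.utils.get_file(
        "flower_photos",
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        untar=True,
    )
)
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir, validation_split=0.2, subset="training",
    seed=123, image_size=(180, 180), batch_size=32,
)
class_names = train_ds.class_names

# Plot the first nine images of the first batch in a 3x3 grid.
plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
plt.show()
```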

Let's use two methods when loading data: caching and prefetching. Caching keeps the images in memory after they are loaded off disk during the first epoch of training, which ensures the dataset does not become a bottleneck while training the model. Prefetching overlaps data preprocessing and model execution while training.
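The two calls look like this, shown here on small stand-in datasets so the snippet runs on its own; apply the same two lines to the real train_ds and val_ds built earlier:

```python
import tensorflow as tf

AUTOTUNE = tf.data.AUTOTUNE

# Stand-in datasets; replace with the real train_ds and val_ds.
train_ds = tf.data.Dataset.from_tensor_slices(tf.zeros((32, 180, 180, 3))).batch(8)
val_ds = tf.data.Dataset.from_tensor_slices(tf.zeros((16, 180, 180, 3))).batch(8)

# cache() keeps elements in memory after the first pass over the data;
# prefetch() prepares the next batch while the current one is being trained on.
train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
```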

Let's standardize the RGB values to the range [0, 1] using a rescaling layer, since neural networks train better on small input values.
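A sketch using the Rescaling layer (in TensorFlow versions before 2.6 it lives at layers.experimental.preprocessing.Rescaling):

```python
import tensorflow as tf
from tensorflow.keras import layers

# Maps pixel values from [0, 255] into [0, 1].
normalization_layer = layers.Rescaling(1.0 / 255)

# Quick check: 0 maps to 0.0, 255 maps to 1.0.
pixels = tf.constant([[0.0, 127.5, 255.0]])
print(normalization_layer(pixels).numpy())
```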

Now, we create the model using the Sequential API of Keras. The API adds the layers sequentially, and since we are just stacking layers on top of one another, the Sequential API serves our purpose. We use blocks of convolution and pooling layers, followed by fully connected layers.
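One possible stack in this spirit (the 180x180 input size and the 16/32/64 filter counts follow the standard tf.keras flowers example; they are assumptions, not the only choice):

```python
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

num_classes = 5  # five flower types

model = Sequential([
    # Rescale pixels to [0, 1] as the first layer of the model itself.
    layers.Rescaling(1.0 / 255, input_shape=(180, 180, 3)),
    # Three convolution + max-pooling blocks.
    layers.Conv2D(16, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    # Fully connected head; the final layer outputs one logit per class.
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(num_classes),
])
```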

Let's compile the model and look at its summary. The summary shows the types of layers used, the number of layers, the number of trainable parameters, and the shapes of the intermediate outputs.

Now let's train the model for 10 epochs.
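Compiling and training might look like this, using the Adam optimizer and sparse categorical cross-entropy on logits, as in the standard flowers example. The tiny random dataset at the end is a clearly labeled stand-in so the snippet finishes quickly; pass the real train_ds and val_ds to reproduce the article's run:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

num_classes = 5

# Model repeated from the previous step so this snippet runs standalone.
model = Sequential([
    layers.Rescaling(1.0 / 255, input_shape=(180, 180, 3)),
    layers.Conv2D(16, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(num_classes),
])

# from_logits=True because the last Dense layer has no softmax.
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()

# Tiny random stand-in for train_ds / val_ds so the snippet finishes quickly.
x = (np.random.rand(8, 180, 180, 3) * 255).astype("float32")
y = np.random.randint(0, num_classes, size=(8,))
toy_train = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

epochs = 10
history = model.fit(toy_train, validation_data=toy_train, epochs=epochs)
```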

History for vanilla image recognition model

As we see here, the training accuracy increases steadily and approaches 1.0, whereas the validation accuracy lingers around 0.6, so the gap between the two keeps widening. This is a sign of overfitting. Overfitting happens when the model fits the training data very well but fails to generalize the pattern we are trying to learn to new data. Let's visualize the training and validation metrics in a pair of graphs.
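The plotting step might look like this. The numbers below are placeholder curves shaped like the run described above (training accuracy climbing toward 1.0, validation accuracy flat around 0.6); in a real run you would read the four lists from history.history instead:

```python
import matplotlib.pyplot as plt

epochs = 10
epochs_range = range(epochs)

# Placeholder curves standing in for history.history["accuracy"], etc.
acc = [0.35, 0.50, 0.60, 0.68, 0.75, 0.81, 0.86, 0.90, 0.94, 0.97]
val_acc = [0.45, 0.52, 0.56, 0.58, 0.60, 0.59, 0.61, 0.60, 0.61, 0.60]
loss = [1.50, 1.20, 1.00, 0.85, 0.70, 0.55, 0.42, 0.30, 0.20, 0.12]
val_loss = [1.40, 1.25, 1.15, 1.10, 1.12, 1.18, 1.25, 1.35, 1.45, 1.55]

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc, label="Validation Accuracy")
plt.legend(loc="lower right")
plt.title("Training and Validation Accuracy")

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss, label="Validation Loss")
plt.legend(loc="upper right")
plt.title("Training and Validation Loss")
plt.show()
```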

Training and validation metrics for 10 epochs

The large gap between training accuracy and validation accuracy is a sign of overfitting and needs to be reduced if we want the model to generalize well to unseen data. To reduce overfitting, we use two techniques here: data augmentation and dropout.

One of the causes of overfitting is having too little training data. When the dataset is small, the model ends up learning from noise in the training examples (which is common in real-world data). One way to enlarge the dataset is to transform the existing images using techniques such as rotation, flipping, shifting, zooming, and cropping.

Code for data augmentation:

Using tf.keras for data augmentation
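A sketch of such an augmentation pipeline with Keras preprocessing layers (in TensorFlow versions before 2.6 these layers live under layers.experimental.preprocessing):

```python
import tensorflow as tf
from tensorflow.keras import layers

img_height, img_width = 180, 180

# Random flips, small rotations, and small zooms applied during training only.
data_augmentation = tf.keras.Sequential([
    layers.RandomFlip("horizontal", input_shape=(img_height, img_width, 3)),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])
```

Calling data_augmentation(images, training=True) several times on the same batch and plotting the results shows the random transformations in action.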

Let's look at a few examples of transformed images.

Randomly transformed images

Now, let's use the dropout technique in the model. Dropout randomly deactivates nodes in the layer where it is applied. We drop 20% of the nodes, so we pass 0.2 as the argument to the Dropout layer.

The model now uses both data augmentation and dropout; as the model summary shows, the augmentation and dropout layers have been added. Let's compile the model and train it. As training takes 10 to 15 minutes depending on the hardware it is being trained on, why don't we grab a cup of tea in the meantime?
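A sketch of the full model with augmentation and dropout. The toy dataset at the end is a labeled stand-in so the snippet runs quickly; the real training would use train_ds and val_ds (the reference flowers example trains this model for around 15 epochs):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

num_classes = 5

# Augmentation pipeline repeated from the previous step.
data_augmentation = Sequential([
    layers.RandomFlip("horizontal", input_shape=(180, 180, 3)),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
])

model = Sequential([
    data_augmentation,
    layers.Rescaling(1.0 / 255),
    layers.Conv2D(16, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding="same", activation="relu"),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),  # randomly drop 20% of activations during training
    layers.Flatten(),
    layers.Dense(128, activation="relu"),
    layers.Dense(num_classes),
])

model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.summary()

# Tiny random stand-in dataset so the snippet runs quickly;
# substitute the real train_ds / val_ds to reproduce the article's results.
x = (np.random.rand(8, 180, 180, 3) * 255).astype("float32")
y = np.random.randint(0, num_classes, size=(8,))
toy_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

epochs = 15
history = model.fit(toy_ds, validation_data=toy_ds, epochs=epochs)
```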

History of training after applying data augmentation and dropout

As the history shows, overfitting has decreased after applying data augmentation and dropout; the narrowing of the accuracy gap shows the improvement. Let's plot the same pair of graphs for the metrics.

Seeing the training and validation curves move together is a moment of bliss. Happy, happy :)

This has been a gentle introduction to creating machine learning models with tf.keras.

After creating the model, we need to test it on real-world data. Let's get a random image of a sunflower from the web (you can search for one yourself and replace image_url with your own URL) and check whether our model predicts it correctly.
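A sketch of the prediction step. The image_url below points at an example sunflower image hosted alongside the flowers dataset; the class_names list and the small untrained model are stand-ins so the snippet runs on its own, and in practice you would reuse the trained model from the steps above (in TensorFlow versions before 2.6, load_img and img_to_array live under tf.keras.preprocessing.image):

```python
import numpy as np
import tensorflow as tf

class_names = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]

# Stand-in untrained model; reuse the trained model for a meaningful result.
model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1.0 / 255, input_shape=(180, 180, 3)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(len(class_names)),
])

# Download the test image (cached after the first run).
image_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/592px-Red_sunflower.jpg"
image_path = tf.keras.utils.get_file("Red_sunflower", origin=image_url)

# Load, resize, and add a batch dimension.
img = tf.keras.utils.load_img(image_path, target_size=(180, 180))
img_array = tf.keras.utils.img_to_array(img)
img_array = tf.expand_dims(img_array, 0)  # batch of one

# The model outputs logits; softmax turns them into probabilities.
predictions = model.predict(img_array)
score = tf.nn.softmax(predictions[0])

print(
    "This image most likely belongs to {} with a {:.2f} percent confidence.".format(
        class_names[np.argmax(score)], 100 * np.max(score)
    )
)
```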

Sunflower image using which we are testing our model
Prediction and its probability

Yay, our model predicted it as “sunflowers” with a probability of ~95%. Cool, isn’t it?

Thanks for making it to the end. This article will be followed by other TensorFlow tutorials on computer vision. Stay tuned!
