Image Classification with Variable Input Resolution in Keras
Convolutional neural networks (CNNs) are a class of neural networks widely used for image classification. For an introduction to CNNs, check out this post by Matthew Stewart.
Problem
Many CNN architectures require all input images to have the same resolution (height and width). This requirement arises when the output of the last convolutional layer is flattened and fully connected to a dense layer. The number of dense-layer weights depends on the length of the flattened vector, which in turn depends on the spatial resolution of the convolutional output, so the input resolution must be fixed in advance. The drawback is that every image, during both training and inference, must be resized to that resolution. This adds an extra preprocessing step and discards detail in large images that have to be down-sampled, which may hurt model performance depending on the image data used and what the model is trying to predict.
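To make the constraint concrete, here is a small NumPy sketch of a flatten-then-dense head. The stride of 32, the 512 channels and the 10 classes are hypothetical values chosen for illustration, and feature_map is a stand-in for a real convolutional stack. The dense weight matrix is sized for a 224x224 input, so a 320x320 input produces a longer flattened vector and the matrix product fails.

```python
import numpy as np

STRIDE = 32        # hypothetical total down-sampling factor of the conv layers
CHANNELS = 512     # hypothetical channel count of the last conv layer
NUM_CLASSES = 10   # hypothetical number of classes

def feature_map(h, w):
    """Stand-in for the last conv layer's output on an h x w image."""
    return np.ones((h // STRIDE, w // STRIDE, CHANNELS), dtype=np.float32)

# The dense layer's weights are created once, for the resolution the model is built with.
built_for = feature_map(224, 224).reshape(-1)      # 7 * 7 * 512 = 25088 values
dense_w = np.zeros((built_for.size, NUM_CLASSES), dtype=np.float32)

def flatten_then_dense(h, w):
    flat = feature_map(h, w).reshape(-1)           # length grows with h and w
    return flat @ dense_w                          # only works if len(flat) == 25088

print(flatten_then_dense(224, 224).shape)          # (10,)
try:
    flatten_then_dense(320, 320)                   # 10 * 10 * 512 = 51200 values
except ValueError as err:
    print("shape mismatch:", err)
```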
Variable Input Shape
There is a way to avoid specifying input dimensions when setting up a CNN, allowing variable image resolutions during both training and inference: place a global pooling layer after the final convolutional layer and before any dense layers. Global pooling reduces each channel of a convolutional layer to a single value in a way that does not depend on the channel's resolution, usually by taking either the average or the maximum of all values in the channel. The result is a vector with one value per channel. Because the number of channels is fixed by the model architecture and is independent of the channel resolution, the dense layer that follows always receives an input of the same length, whatever the size of the input image.
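The following NumPy sketch illustrates why global pooling removes that constraint. global_average_pool and global_max_pool are illustrative helpers (in Keras the corresponding layers are GlobalAveragePooling2D and GlobalMaxPooling2D), and the 512 channels and 10-class dense weights are hypothetical. Feature maps of three different spatial sizes all reduce to a 512-value vector, so the same dense weights apply to each.

```python
import numpy as np

def global_average_pool(fmap):
    """(height, width, channels) -> (channels,): mean of each channel."""
    return fmap.mean(axis=(0, 1))

def global_max_pool(fmap):
    """(height, width, channels) -> (channels,): maximum of each channel."""
    return fmap.max(axis=(0, 1))

CHANNELS = 512     # hypothetical channel count of the last conv layer
NUM_CLASSES = 10   # hypothetical number of classes

rng = np.random.default_rng(0)
# Dense weights are created once and never depend on the image resolution.
dense_w = rng.standard_normal((CHANNELS, NUM_CLASSES))

for h, w in [(7, 7), (10, 13), (3, 5)]:          # spatial sizes of the conv output
    fmap = rng.standard_normal((h, w, CHANNELS))
    pooled = global_average_pool(fmap)           # always (512,)
    logits = pooled @ dense_w                    # always (10,)
    print((h, w), "->", pooled.shape, "->", logits.shape)
```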
Tutorial
Now let’s take a look at how to train a CNN with variable image resolution using Keras. We’ll use the Imagenette dataset, a 10-class subset of ImageNet, which can be found here. The following code was written with TensorFlow 2.1.
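First, we build our neural network. We will use one of the CNN models supplied in the Keras applications module. Don’t forget to include a global pooling layer; when include_top=False, the application constructors can add one for you through their pooling argument ('avg' or 'max').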
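A minimal sketch of such a model, assuming MobileNetV2 as the backbone (the post does not say which application model it used) and a 10-unit softmax head for Imagenette's 10 classes. Leaving height and width as None in input_shape, setting include_top=False, and passing pooling='avg' gives a network that accepts any resolution. weights=None keeps the example offline; pass weights='imagenet' to start from pretrained weights instead.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10  # Imagenette has 10 classes

def build_model(num_classes=NUM_CLASSES):
    # Assumed backbone: MobileNetV2. Height and width are None, so any resolution is accepted.
    base = tf.keras.applications.MobileNetV2(
        input_shape=(None, None, 3),
        include_top=False,   # drop the fixed-size classifier head
        weights=None,        # random init; "imagenet" would download pretrained weights
        pooling="avg",       # global average pooling after the last conv block
    )
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(base.output)
    return tf.keras.Model(inputs=base.input, outputs=outputs)

model = build_model()
for h, w in [(128, 128), (96, 160)]:
    probs = model(np.zeros((1, h, w, 3), dtype=np.float32))
    print((h, w), "->", probs.shape)
```

Calling the model on two differently sized inputs returns an output of shape (1, 10) both times; only the global pooling step sits between the variable-size feature maps and the dense layer.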
Next, we need to build a batch generator class that loads and preprocesses our data during training. While doing so, we have to keep in mind some limits on how much image resolution can vary. If an image is too large, we may run out of memory, so we set a threshold for the maximum height and width we allow and down-sample anything above it. Also, the images in a batch are stacked into a single tensor, and TensorFlow requires every tensor to have a uniform shape. We therefore pad the images in each batch to match the height and width of the largest image in that batch. A batch containing a single image needs no padding, so we can run inference on images without any padding by feeding them one at a time.
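The generator described above might look like the following sketch. It is written in plain NumPy so that it runs on its own. The images are assumed to be in-memory uint8 arrays of shape (height, width, 3); MAX_SIDE = 512 is a hypothetical threshold; down-sampling uses simple nearest-neighbour indexing rather than a proper resize such as tf.image.resize; and padding is zeros at the bottom and right. VariableSizeBatches, cap_resolution and pad_batch are hypothetical names. The class follows the tf.keras.utils.Sequence protocol (__len__, __getitem__, on_epoch_end); to pass it to model.fit, make it subclass tf.keras.utils.Sequence and call super().__init__() in its constructor.

```python
import math
import numpy as np

MAX_SIDE = 512  # hypothetical cap on height and width to bound memory use

def cap_resolution(img, max_side=MAX_SIDE):
    """Nearest-neighbour down-sample so neither side exceeds max_side (aspect ratio kept)."""
    h, w = img.shape[:2]
    scale = min(1.0, max_side / max(h, w))
    if scale == 1.0:
        return img
    new_h, new_w = max(1, round(h * scale)), max(1, round(w * scale))
    rows = np.arange(new_h) * h // new_h   # source row for each output row
    cols = np.arange(new_w) * w // new_w   # source column for each output column
    return img[rows][:, cols]

def pad_batch(images):
    """Zero-pad every image (bottom/right) to the largest height and width in the batch."""
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    batch = np.zeros((len(images), max_h, max_w, images[0].shape[2]), dtype=np.float32)
    for i, img in enumerate(images):
        batch[i, : img.shape[0], : img.shape[1]] = img
    return batch

class VariableSizeBatches:
    """Minimal batch generator following the keras.utils.Sequence protocol."""

    def __init__(self, images, labels, batch_size=8, shuffle=True, seed=0):
        self.images, self.labels = images, np.asarray(labels)
        self.batch_size, self.shuffle = batch_size, shuffle
        self.rng = np.random.default_rng(seed)
        self.order = np.arange(len(images))
        self.on_epoch_end()

    def __len__(self):
        return math.ceil(len(self.images) / self.batch_size)

    def __getitem__(self, idx):
        ids = self.order[idx * self.batch_size : (idx + 1) * self.batch_size]
        imgs = [cap_resolution(self.images[i]).astype(np.float32) / 255.0 for i in ids]
        return pad_batch(imgs), self.labels[ids]

    def on_epoch_end(self):
        if self.shuffle:
            self.rng.shuffle(self.order)
```

A batch of one image comes back with exactly that image's height and width, which is what makes padding-free, one-at-a-time inference possible.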
All that is left is to specify our training hyperparameters, initialize train and test generators, and train our model.
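A compact sketch of this final step, assuming integer class labels (hence sparse categorical cross-entropy), the Adam optimizer, and placeholder hyperparameters. To keep it self-contained and quick to run, it uses a hypothetical RandomSizeBatches sequence that produces random batches whose height and width change from batch to batch, plus a small convolutional network ending in GlobalAveragePooling2D, in place of the real Imagenette generators and application model.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10
BATCH_SIZE = 4        # hypothetical hyperparameters
EPOCHS = 1
LEARNING_RATE = 1e-3

class RandomSizeBatches(tf.keras.utils.Sequence):
    """Synthetic stand-in for a real generator: each batch has its own height and width."""

    def __init__(self, batch_count, seed):
        super().__init__()
        self.batch_count = batch_count
        self.rng = np.random.default_rng(seed)

    def __len__(self):
        return self.batch_count

    def __getitem__(self, idx):
        h, w = 32 + 8 * (idx % 4), 40 + 8 * (idx % 3)   # resolution varies per batch
        x = self.rng.random((BATCH_SIZE, h, w, 3), dtype=np.float32)
        y = self.rng.integers(0, NUM_CLASSES, BATCH_SIZE)  # integer labels
        return x, y

# Small stand-in model; any network ending in global pooling trains the same way.
inputs = tf.keras.Input(shape=(None, None, 3))
x = tf.keras.layers.Conv2D(16, 3, strides=2, activation="relu")(inputs)
x = tf.keras.layers.Conv2D(32, 3, strides=2, activation="relu")(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

train_gen = RandomSizeBatches(batch_count=8, seed=0)
test_gen = RandomSizeBatches(batch_count=2, seed=1)
history = model.fit(train_gen, validation_data=test_gen, epochs=EPOCHS)
```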
Conclusion
In this post, we’ve explored how to use global pooling layers to create and train a CNN that can handle variable image resolution. While we focused on resolution variability, global pooling layers provide other benefits that we did not explore, such as nonlinearity (in the case of max pooling) and regularization (global pooling has no trainable parameters, so it adds far fewer weights than a flattened dense layer). Consider using global pooling for reasons beyond input shapes.
Keep in mind:
While our model can accept any image resolution that fits in memory, it is generally a good idea to train on images of roughly the same size. That way, the objects and features the model sees appear at a similar scale, which makes it easier for the model to generalize. Similarly, for best results, run inference on images close in size to those used in training. How much this matters depends on the specific use case and the images involved.
Some improvements that could be made to the training data generator used in this tutorial:
1) Random batch image augmentations.
2) Random batch padding styles.
3) Parallel data processing.
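As one possible reading of the second suggestion, the sketch below picks a random np.pad mode for each batch and places each image at a random offset inside the padded canvas instead of always in the top-left corner. The PAD_MODES menu and the pad_batch_random helper are hypothetical and only illustrate the idea.

```python
import numpy as np

PAD_MODES = ("constant", "edge", "reflect")  # hypothetical menu of padding styles

def pad_batch_random(images, rng):
    """Pad images to the batch's largest height and width using a random
    np.pad mode (one per batch) and a random placement (one per image)."""
    max_h = max(img.shape[0] for img in images)
    max_w = max(img.shape[1] for img in images)
    mode = PAD_MODES[int(rng.integers(len(PAD_MODES)))]
    padded = []
    for img in images:
        extra_h = max_h - img.shape[0]
        extra_w = max_w - img.shape[1]
        top = int(rng.integers(0, extra_h + 1))   # rows of padding above the image
        left = int(rng.integers(0, extra_w + 1))  # columns of padding left of the image
        pad_width = ((top, extra_h - top), (left, extra_w - left), (0, 0))
        padded.append(np.pad(img, pad_width, mode=mode))
    return np.stack(padded)

rng = np.random.default_rng(0)
batch = [np.ones((4, 6, 3)), np.ones((5, 3, 3)), np.ones((2, 2, 3))]
print(pad_batch_random(batch, rng).shape)  # (3, 5, 6, 3)
```

A single-image batch still comes back unpadded, since there is no extra height or width to fill.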
Masala.AI
The Mindboard Data Science Team explores cutting-edge technologies in innovative ways to provide original solutions, including the Masala.AI product line. Masala provides media content rating services such as vRate, a browser extension that detects and blocks mature content with custom sensitivity settings. The vRate browser extension is available for download via the Chrome Web Store. Check out www.masala.ai for more info.