Automatic Augmentation Search with Low Computational Effort

Overview

Data augmentation is a widely adopted technique to improve generalization [3] by applying transformations (or sets of combined transformations) during training and optionally at test time. Practitioners often search for the best set of transformations manually, either by relying on domain expertise or by making assumptions about generally useful transformations. This manual search for optimal augmentations can be time-consuming and compute-intensive.

The fast.ai library does a great job of providing smart defaults in areas like hyperparameters, and augmentations are no exception. After a lot of experimentation, the fast.ai team settled on a standard set of augmentations which is applied indiscriminately to every dataset and has proven to work well most of the time. However, this one-size-fits-all simplification clearly leaves room for improvement.

Over the last few months, our team at fellowship.ai has been working on a method to automate the search for the best augmentation set, in a computationally efficient manner and with as little domain-specific input as possible. This research aligns with the goal of platform.ai, which seeks to offer deep learning to a broader (non-technical) audience.

Current approaches

There are a few approaches to automatically finding a set of transforms, but most of them focus on directly generating the augmented data. A disadvantage of this approach is that it cannot operate on a subset: it needs to be run on the whole dataset, which is computationally expensive. Moreover, the augmentation search needs to be repeated whenever the training set changes.

Alex Ratner and Henry Ehrenberg proposed Transformation Adversarial Networks for Data Augmentation, which use GANs to find the set of transformation parameters that produce augmented data lying within a defined distribution of interest, one that should be representative of the training set.

Cubuk and Zoph from Google created AutoAugment [1], which uses a reinforcement learning search to find the best policy (set of transformations) for a given dataset.

We are looking for a more computationally efficient and domain-agnostic approach to automatically finding a set of parameters for the random image transform functions provided by the fast.ai library.

Methodology

The fast.ai library provides different types of image transformations which are used for data augmentation. These transformations can be grouped into affine, pixel, cropping, lighting and coordinate transformations. The default augmentation set, obtained by calling get_transforms(), has the following parameters:

Table 1: List of transformations available in fast.ai. More transforms are available in the fast.ai documentation, but we focused on the ones in this list.
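
For reference, a sketch of the default call in fastai v1 (parameter names and defaults as documented for v1; check your installed version):

```python
from fastai.vision import get_transforms

# Returns a (train_tfms, valid_tfms) tuple; these are the v1 defaults.
tfms = get_transforms(
    do_flip=True,      # random horizontal flip
    flip_vert=False,   # no vertical/dihedral flips by default
    max_rotate=10.0,   # rotation angle drawn from [-10, 10] degrees
    max_zoom=1.1,      # zoom factor drawn from [1.0, 1.1]
    max_lighting=0.2,  # maximum brightness/contrast change
    max_warp=0.2,      # maximum symmetric perspective warp
    p_affine=0.75,     # probability of applying each affine transform
    p_lighting=0.75,   # probability of applying each lighting transform
)
```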

Method Validation

The performance of a network trained with the chosen transformations will be compared to that of a network trained with the default fast.ai augmentation set, by training both networks with the same routine. This will be done for multiple datasets. The performance metric is the error rate on the validation set.
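
A minimal sketch of this comparison, assuming a hypothetical train_and_eval helper (not fast.ai API) that trains with a fixed routine and returns the validation error rate:

```python
from fastai.vision import get_transforms

# Hypothetical containers, for illustration only:
#   datasets      -> name -> the dataset object to train on
#   searched_tfms -> name -> transforms found by our TTA search
for name, data in datasets.items():
    err_search = train_and_eval(data, tfms=searched_tfms[name])
    err_default = train_and_eval(data, tfms=get_transforms())
    print(f"{name}: searched {err_search:.3f} vs. default {err_default:.3f}")
```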

Datasets

Datasets from different domains were selected to demonstrate the effectiveness of this method across domains. From a group of datasets representative of very different use cases, the following six were selected.

TTA Search

We can see from the image below that Test Time Augmentation (TTA) can be a strong indicator of whether a certain augmentation is a good candidate for training.

Figure 1: Normalized error rate of a network trained with a certain augmentation vs. the normalized TTA error for that augmentation. The normalization is the relative change with respect to the baseline (no-augmentation) case. Each dot is a different augmentation tried on the network; each color is an experiment on a different dataset.

We identified that transformations which are very harmful for training, characterized by a high err/err_none, also had a very high TTA_err/TTA_err_none value. For the Pets and Dogs datasets, for example, it was the dihedral transform, while for Planet it was resize_crop. Many transforms in our set were detrimental for CIFAR-10.
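
To make the normalization in Figure 1 concrete, here is a small sketch; the baselines and per-transform errors below are purely illustrative numbers, not measured results.

```python
# Ratios relative to the no-augmentation baseline (1.0 = no change).
err_none, tta_err_none = 0.080, 0.085  # illustrative baselines
results = {                            # transform -> (TTA err, post-training err)
    "flip_lr": (0.080, 0.074),         # ratios < 1: promising candidate
    "dihedral": (0.140, 0.120),        # ratios > 1: likely harmful
}

for name, (tta_err, err) in results.items():
    print(f"{name}: TTA_err/TTA_err_none = {tta_err / tta_err_none:.2f}, "
          f"err/err_none = {err / err_none:.2f}")
```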

Based on the observed behaviour, we search for an augmentation set using TTA, with a procedure that works as follows (a code sketch is given after the list).

  1. Split the training set into two subsets of size 80% and 20%, respectively.
  2. Train the last layer group on 80% of the training set for EPOCHS_HEAD epochs, without any data augmentation.
  3. Calculate the error rate ERR_NONE on the remaining 20% of the training set.
  4. For each kind of transformation and each possible magnitude, calculate the TTA error rate on the remaining 20% of the training set. For TTA, we base predictions on WEIGHT_UNTRANSFORMED * LOGITS_UNTRANSFORMED + (1 - WEIGHT_UNTRANSFORMED) * LOGITS_TRANSFORMED, where 1 - WEIGHT_UNTRANSFORMED determines how much influence the augmentation has on the prediction.
  5. For each kind of transformation, choose the magnitude which leads to the lowest TTA error rate, if that error rate is lower than THRESHOLD * ERR_NONE; otherwise, don't include that kind of transformation in the final set of augmentations.
  6. With the chosen set of augmentations, train the head for EPOCHS_HEAD epochs and the full network for EPOCHS_FULL.
  7. As a baseline, train the network for the same number of epochs using the transforms provided by get_transforms().
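
The sketch below outlines steps 1 through 5. All helpers (train_head, predict_logits, apply_tfm) and the candidate grid are hypothetical placeholders, not fast.ai API, and the hyperparameter values are illustrative.

```python
import random

EPOCHS_HEAD = 3              # illustrative value
THRESHOLD = 1.0              # keep a transform only if it beats ERR_NONE
WEIGHT_UNTRANSFORMED = 0.5   # blend weight of the untransformed logits

def error_rate(model, samples):
    """Step 3: plain error rate, no augmentation."""
    wrong = sum(int(predict_logits(model, x).argmax() != y) for x, y in samples)
    return wrong / len(samples)

def tta_error_rate(model, samples, tfm_name, magnitude):
    """Step 4: error of the weighted blend of plain and transformed logits."""
    wrong = 0
    for x, y in samples:
        logits = (WEIGHT_UNTRANSFORMED * predict_logits(model, x)
                  + (1 - WEIGHT_UNTRANSFORMED)
                  * predict_logits(model, apply_tfm(x, tfm_name, magnitude)))
        wrong += int(logits.argmax() != y)
    return wrong / len(samples)

def tta_search(model, train_set, candidates):
    """candidates: dict mapping a transform name to the magnitudes to try."""
    # Step 1: 80/20 split of the training set.
    random.shuffle(train_set)
    cut = int(0.8 * len(train_set))
    head_set, holdout = train_set[:cut], train_set[cut:]

    # Steps 2-3: train the head without augmentation, measure the plain error.
    train_head(model, head_set, epochs=EPOCHS_HEAD, tfms=None)
    err_none = error_rate(model, holdout)

    # Steps 4-5: per transform, keep the best magnitude if it helps enough.
    chosen = []
    for name, magnitudes in candidates.items():
        best_err, best_m = min(
            (tta_error_rate(model, holdout, name, m), m) for m in magnitudes
        )
        if best_err < THRESHOLD * err_none:
            chosen.append((name, best_m))
    # Steps 6-7 (final training with `chosen` vs. get_transforms()) omitted.
    return chosen
```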

Out of the transformations available in the fast.ai library, we have tested our method with the transforms and parameters listed below.

Table 2: List of tested augmentations using our search method, including the tested parameters. * marks the probability with which each transform was used for final training. For the TTA search itself the probability was 1.0.

Findings

We have found that the TTA error for a transformation carries information about the performance improvement that augmentation will bring, and helps to rapidly decide which image transformations are constructive for better network generalization.

The following results show the performance improvement achieved by this method, in comparison to the baseline case.

Table 3: Top-1 error rates for the found augmentation sets.
UPDATE: The platform.ai team has just reached SOTA performance on the Food-101 dataset using our augmentation search technique.

The augmentations picked for each dataset are detailed below.

Table 4: List of selected augmentations and their parameters, for each dataset. The values between brackets represent the bounds of the uniform distribution used by the RandTransform class.
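
As an illustration of that notation: in fastai v1, a tuple argument to a transform builds a RandTransform whose value is drawn uniformly from the range each time the transform is applied. A sketch, assuming fastai v1's transform API (values illustrative):

```python
from fastai.vision import brightness, flip_lr, rotate

# Tuple arguments are resolved by uniform sampling on every application.
tfms = [
    rotate(degrees=(-30, 30), p=0.75),      # angle ~ U(-30, 30) degrees
    brightness(change=(0.4, 0.6), p=0.75),  # 0.5 means "leave unchanged"
    flip_lr(p=0.5),                         # left-right flip half the time
]
```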

It is worth noting that the augmentations picked by our method seemed qualitatively reasonable. For example, for the Planet dataset it chooses dihedral flips (which include upside-down flips), while for Kuzushiji-MNIST it chooses neither left-right nor dihedral flips, since any of these flips would be damaging for that task.

With more time, it would be worth investigating how consistent these differences in error rate are: the gap between the TTA-based selection of augmentations and get_transforms() might be smaller than the run-to-run variation for a fixed set of augmentations.

Acknowledgements

We would like to thank Jeremy Howard for his insightful and experienced guidance. Our gratitude goes as well to Arshak Navruzyan and David Tedaldi for their support.

References

  1. Cubuk, Ekin D., et al. AutoAugment: Learning Augmentation Policies from Data. arXiv preprint arXiv:1805.09501 (2018).
  2. Geng, Mingyang, et al. Learning data augmentation policies using augmented random search. arXiv preprint arXiv:1811.04768 (2018).
  3. Perez, Luis, et al. The Effectiveness of Data Augmentation in Image Classification using Deep Learning. arXiv preprint arXiv:1712.04621 (2017).
  4. Krizhevsky, Alex, et al. ImageNet Classification with Deep Convolutional Neural Networks. Advances in Neural Information Processing Systems 25 (2012). doi:10.1145/3065386.

Originally published at HI, I AM CRISTIAN DUGUET.