Data Augmentation to solve imbalanced training data for Image Classification
This article walks through how to use data augmentation to address imbalanced image classification data. Imbalanced training data can bias a classifier toward the over-represented classes. When it is not feasible to collect more training data for the under-represented classes, data augmentation can be used to increase the amount of training data available for them.
In this article, I go over a few techniques that can be used to augment training data for imbalanced classes.
First, let’s read the actual image.
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import random

image = Image.open("odin.jpg")
img = np.array(image)
plt.imshow(img)
plt.show()
We can perform several transformations on the original image to augment our training data, such as:
1. Flipping the image horizontally
flipped_img = np.fliplr(img)
Image.fromarray(flipped_img)
2. Rotating the image by a random degree
random_degree = random.uniform(-50, 50)
image.rotate(random_degree)
3. Adding random noise to the Image
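# Derive a scaling factor from the number of distinct pixel values
vals = len(np.unique(img))
vals = 0.9 ** np.ceil(np.log2(vals))
# Work on the NumPy array (img) rather than the PIL image
noisy = np.random.poisson(img * vals) / float(vals)
# Clip to the valid pixel range before casting to uint8
noisy = np.clip(noisy, 0, 255)
plt.imshow(noisy.astype('uint8'))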
plt.show()
4. Cropping the image
width, height = image.size
left = 0
top = height/4
right = width*0.9
bottom = 2*height/2

image = image.crop((left, top, right, bottom))

plt.imshow(image)
plt.show()
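The individual transformations above can be combined into a single helper that returns a randomly altered copy of an image. The following is only a minimal sketch, not a definitive implementation: the function name random_augment, the scale parameter and the 80-100% crop fractions are illustrative assumptions, and rotation is left out so that the sketch depends on NumPy alone.

```python
import numpy as np

def random_augment(img, rng, scale=0.5):
    """Return a randomly transformed copy of an H x W x C uint8 image array.

    `scale` is a hypothetical knob for the Poisson noise: smaller values
    produce noisier images.
    """
    out = img

    # Flip horizontally with probability 0.5.
    if rng.random() < 0.5:
        out = np.fliplr(out)

    # Random crop that keeps 80-100% of each side (illustrative fractions).
    h, w = out.shape[:2]
    new_h = int(h * rng.uniform(0.8, 1.0))
    new_w = int(w * rng.uniform(0.8, 1.0))
    top = rng.integers(0, h - new_h + 1)
    left = rng.integers(0, w - new_w + 1)
    out = out[top:top + new_h, left:left + new_w]

    # Poisson noise, clipped back into the valid pixel range.
    noisy = rng.poisson(out.astype(np.float64) * scale) / scale
    return np.clip(noisy, 0, 255).astype(np.uint8)


rng = np.random.default_rng(seed=0)
example = np.full((40, 60, 3), 128, dtype=np.uint8)  # stand-in for a real image
augmented = random_augment(example, rng)
print(augmented.shape, augmented.dtype)
```

Because the crop changes the image size, the result would normally be resized back to the model's input size before training.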
Keras also provides a simple and effective way to perform data augmentation via the keras.preprocessing.image.ImageDataGenerator class. This class allows you to:
- configure random transformations and normalisation operations to be done on your image data during training
- instantiate generators of augmented image batches (and their labels) via .flow(data, labels) or .flow_from_directory(directory). These generators can then be used with the Keras model methods that accept data generators as inputs: fit_generator, evaluate_generator and predict_generator.
from keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest')
Using data augmentation, we can quickly increase the amount of data for our under-represented classes. When the transformations are applied randomly during training, the model rarely sees exactly the same image twice, which helps reduce overfitting and helps the model generalise better.
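To decide how much augmentation each class needs, one simple option is to count the labels and top every class up to the size of the largest one. The sketch below assumes that balancing rule; the helper name augmentation_plan and the example labels are hypothetical and not taken from any library.

```python
from collections import Counter

def augmentation_plan(labels):
    """Map each class to the number of extra samples needed to match the largest class."""
    counts = Counter(labels)
    target = max(counts.values())
    return {cls: target - n for cls, n in counts.items()}


# Hypothetical label list for an imbalanced three-class dataset.
labels = ["cat"] * 500 + ["dog"] * 120 + ["fox"] * 40
plan = augmentation_plan(labels)
print(plan)  # {'cat': 0, 'dog': 380, 'fox': 460}
```

Each class would then receive that many augmented copies, drawn by applying random transformations to its existing images.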