Implement fit_generator() in Keras

An Nguyen
2 min read · Feb 2, 2017


Here is an example of fit_generator():

model.fit_generator(generator(features, labels, batch_size), samples_per_epoch=50, nb_epoch=10)

Breaking it down:

generator(features, labels, batch_size): generates batches of samples indefinitely

samples_per_epoch: the number of samples you want to train on in each epoch

nb_epoch: number of epochs
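(Note that in Keras 2 these arguments were renamed: steps_per_epoch counts batches rather than samples, and epochs replaces nb_epoch. A rough equivalent of the call above, assuming batches of 10 samples as in the rest of this post, would be:)

model.fit_generator(generator(features, labels, batch_size),
                    steps_per_epoch=5, epochs=10)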

While samples_per_epoch and nb_epoch are just numbers you set, you have to write the code for generator yourself. Here is an example:

Assume features is an array of data with shape (100, 64, 64, 3) and labels is an array with shape (100, 1). We use the data in features and labels to train our model.
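If you just want to try this out without a real dataset, you can stand in random placeholder arrays of those shapes (purely synthetic; the binary labels are only an assumption for this sketch):

import numpy as np

features = np.random.rand(100, 64, 64, 3)          # 100 dummy 64x64 RGB images
labels = np.random.randint(0, 2, size=(100, 1))    # 100 dummy binary labels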

import numpy as np

def generator(features, labels, batch_size):
    # Create empty arrays to hold one batch of features and labels
    batch_features = np.zeros((batch_size, 64, 64, 3))
    batch_labels = np.zeros((batch_size, 1))
    while True:
        for i in range(batch_size):
            # choose a random index into features
            index = np.random.choice(len(features))
            # some_processing is a placeholder for whatever preprocessing you apply
            batch_features[i] = some_processing(features[index])
            batch_labels[i] = labels[index]
        yield batch_features, batch_labels
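To sanity-check the generator before handing it to Keras, you can pull a single batch yourself (this just exercises the function; nothing here is Keras-specific):

gen = generator(features, labels, batch_size=10)
x_batch, y_batch = next(gen)
print(x_batch.shape, y_batch.shape)  # expect (10, 64, 64, 3) and (10, 1)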

With the generator above, if we define batch_size = 10, each call randomly draws 10 samples from features and labels and feeds them to the model, until the epoch reaches the 50-sample limit set by samples_per_epoch (i.e. 5 batches per epoch). fit_generator() then discards the used batches and repeats the same process for the next epoch.
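Putting it together, a rough end-to-end sketch might look like the following. The model here is only a placeholder (a single dense layer on flattened 64x64x3 images with a sigmoid output, assuming binary labels); any compiled Keras model would do, and the argument names follow the Keras 1.x API used in this post.

from keras.models import Sequential
from keras.layers import Flatten, Dense

# Placeholder model: flatten each 64x64x3 image and predict a single probability
model = Sequential()
model.add(Flatten(input_shape=(64, 64, 3)))
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy')

# Train from the generator: 5 batches of 10 samples per epoch, for 10 epochs
model.fit_generator(generator(features, labels, batch_size=10),
                    samples_per_epoch=50, nb_epoch=10)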

One great advantage of fit_generator(), besides saving memory, is that you can integrate random augmentation inside the generator, so the model is always fed new data generated on the fly.
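As a minimal sketch of what that can look like, some_processing (the placeholder used in the generator above) could apply a random horizontal flip before copying the image into the batch:

import numpy as np

def some_processing(image):
    # Randomly flip the image left-right half the time (simple on-the-fly augmentation)
    if np.random.rand() < 0.5:
        return np.fliplr(image)
    return image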

For more information on the fit_generator() arguments, refer to the Keras documentation.

I hope you found this content helpful. If so, please hit ❤ to share it; I really appreciate any feedback. Until next time!
