Shuffle the Batched or Batch the Shuffled, that is the question!

Ashraf Dasa
5 min read · Oct 12, 2021


TensorFlow Dataset API

The Dataset API, provided by TensorFlow, allows developers to work with data of all sizes in a uniform way.

First, we construct a printing function that will be used to display the datasets. The function takes a quantity parameter, which defaults to 5.

import tensorflow as tf

def printDs(ds, quantity=5):
    print('---------------------')
    # Take `quantity` elements from the dataset and print each one
    for example in ds.take(quantity):
        print(example)

Now we can set up some data to work with. Using Python's range() function, we create a list of numbers from 0 to 99.

data = list(range(100))
ds = tf.data.Dataset.from_tensor_slices(data)

from_tensor_slices() then converts the list into a TensorFlow dataset.

Now we can print a sample of ds. It is a sequence of tensors starting from 0, and the function uses its default quantity value of 5:
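printDs(ds)

# ---------------------
# tf.Tensor(0, shape=(), dtype=int32)
# tf.Tensor(1, shape=(), dtype=int32)
# tf.Tensor(2, shape=(), dtype=int32)
# tf.Tensor(3, shape=(), dtype=int32)
# tf.Tensor(4, shape=(), dtype=int32)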

Shuffle

This function shuffles the tensors. The most important parameter to understand is buffer_size:

shuffled = ds.shuffle(buffer_size=5)
printDs(shuffled,50)

Notice that the tensor values are low at the beginning, even though the full dataset has 100 elements. This is because the buffer size is 5.

TensorFlow fills a buffer with the first 5 elements, picks one of them at random, and replaces it with the next element from the dataset. The first output can therefore only be one of the first 5 elements, which is why the numbers keep increasing as we go down the list.
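To build intuition, here is a minimal pure-Python sketch of this buffered-shuffle idea (an illustration of the behavior, not TensorFlow's actual implementation):

import random

def buffered_shuffle(iterable, buffer_size, seed=None):
    rng = random.Random(seed)
    it = iter(iterable)
    # Fill the buffer with the first buffer_size elements
    buffer = [x for _, x in zip(range(buffer_size), it)]
    while buffer:
        # Emit a random element from the buffer...
        i = rng.randrange(len(buffer))
        yield buffer[i]
        # ...and refill its slot from the stream, if anything is left
        try:
            buffer[i] = next(it)
        except StopIteration:
            buffer.pop(i)

print(list(buffered_shuffle(range(20), buffer_size=5)))

Early outputs can only come from the first few elements of the stream, exactly as we observed above.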

We can also try a shuffle buffer size of 1, which shuffles each element with itself and produces the same order as the original list:
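shuffled = ds.shuffle(buffer_size=1)
printDs(shuffled, 10)  # identical to the unshuffled dataset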

Note: shuffle does not save its results back to the dataset; the shuffle is executed again every time we iterate over the dataset. The example below prints the dataset twice and gets two different results.

shuffled = ds.shuffle(buffer_size=5)
printDs(shuffled,10) # print first time

printDs(shuffled,10) # print second time

To get a repeatable shuffled order, we can pass reshuffle_each_iteration=False; in this case the same shuffled order is reused on every iteration, as shown below.

shuffled = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False)
printDs(shuffled,10) # print first time

printDs(shuffled,10) # print second time

Batching

When training or testing, we usually send a group of examples to the model instead of a single value. Batching groups the tensors so they are passed to the model together. The batch function takes a batch_size argument, which is how many tensors we want in each batch.

batched = ds.batch(10)
printDs(batched)

In some cases we end up with a remainder. For example, if we batch the 100 elements with batch size 14, we get 7 full batches (98 elements), and the last batch has only 2 elements:
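batched = ds.batch(14)
printDs(batched, 10)
# The last tensor printed is [98 99], a batch of only 2 elements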

We can get rid of the remainder using drop_remainder=True:

batched = ds.batch(14, drop_remainder=True)
printDs(batched,10)

Shuffle the batched

If we add the shuffle API on top of the batched data as follows:

Shuffle_batched = ds.batch(14, drop_remainder=True).shuffle(buffer_size=5)
printDs(Shuffle_batched,10)

As the output shows, the batches are not in order, but the contents of each batch are in order. This is because batch comes first, so shuffle operates at the level of whole batches.

Batch the shuffled

Alternatively, if we shuffle first and then batch, the batches arrive in order, but each batch contains its own shuffled elements.

Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True)
printDs(Batch_shuffled, 10)

Combining all

To cover both cases, we can shuffle the batches of a shuffled dataset:

shuffle_Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True).shuffle(buffer_size=50)
printDs(shuffle_Batch_shuffled,10)

In this case each batch has its own shuffled elements, and the dataset has its own shuffled batches.

We can now write the complete code that stops reshuffling every time we iterate over the dataset ds:

myFinalDs = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False).batch(14, drop_remainder=True).shuffle(buffer_size=50, reshuffle_each_iteration=False)
printDs(myFinalDs, 10)
printDs(myFinalDs, 10)

With this, we have a fixed dataset that has shuffled elements and its own shuffled batches.

Conclusion

The shuffle and batch APIs are very helpful tools when preparing data for TensorFlow models. There is no right or wrong order; each problem has its own requirements. The explanation above shows the flexibility of the batch and shuffle dataset APIs.

Full Code
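For reference, here are the snippets from this article assembled into one runnable script:

import tensorflow as tf

def printDs(ds, quantity=5):
    print('---------------------')
    for example in ds.take(quantity):
        print(example)

# Build a dataset of the numbers 0..99
data = list(range(100))
ds = tf.data.Dataset.from_tensor_slices(data)
printDs(ds)

# Shuffle with a small buffer: early outputs stay small
shuffled = ds.shuffle(buffer_size=5)
printDs(shuffled, 50)

# Fix the shuffled order across iterations
shuffled = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False)
printDs(shuffled, 10)
printDs(shuffled, 10)

# Batching, with and without the remainder
batched = ds.batch(10)
printDs(batched)
batched = ds.batch(14, drop_remainder=True)
printDs(batched, 10)

# Shuffle the batched: whole batches are reordered
Shuffle_batched = ds.batch(14, drop_remainder=True).shuffle(buffer_size=5)
printDs(Shuffle_batched, 10)

# Batch the shuffled: batches in order, elements shuffled
Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True)
printDs(Batch_shuffled, 10)

# Combining all: shuffled elements and shuffled batches
shuffle_Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True).shuffle(buffer_size=50)
printDs(shuffle_Batch_shuffled, 10)

# Fixed final dataset: same order every iteration
myFinalDs = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False).batch(14, drop_remainder=True).shuffle(buffer_size=50, reshuffle_each_iteration=False)
printDs(myFinalDs, 10)
printDs(myFinalDs, 10)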
