Shuffle the Batched or Batch the Shuffled, that is the question!
TensorFlow Dataset API
The Dataset API is provided by TensorFlow, allowing developers to work with data of all sizes in a uniform way.
First, we construct a printing function that will be used to display the datasets. The function takes a quantity parameter, which defaults to 5.
import tensorflow as tf

def printDs(ds, quantity=5):
    print('---------------------')
    for example in ds.take(quantity):
        print(example)
Now, we can set up some data to work with. Using Python's range() function, we create a list of numbers from 0 to 99.
data = list(range(100))
ds = tf.data.Dataset.from_tensor_slices(data)
The from_tensor_slices() call converts the list into a TensorFlow dataset.
Now we can print a sample of ds: it is a sequence of tensors starting from 0, and the function uses its default quantity value of 5.
Shuffle
This function shuffles the tensors; the key parameter to focus on is buffer_size.
shuffled = ds.shuffle(buffer_size=5)
printDs(shuffled,50)
Notice that the tensor values are low at the beginning, even though the dataset contains 100 elements. This is because the buffer size is 5: TensorFlow keeps a buffer of 5 elements, picks a random element from it, and refills the buffer from the dataset as it goes. That is why the numbers keep increasing as we go down the list.
We can also try a shuffle buffer size of 1, which shuffles each single element with itself and yields the same order as the original list.
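To build intuition, here is a minimal pure-Python sketch of the buffered-shuffle idea. This is an illustrative model of the mechanism, not TensorFlow's actual implementation, and the helper name buffered_shuffle is made up:

```python
import random

def buffered_shuffle(data, buffer_size, seed=0):
    """Sketch of buffered shuffling: keep a buffer of `buffer_size`
    elements, repeatedly emit a random one, and refill from the stream."""
    rng = random.Random(seed)
    it = iter(data)
    buffer = []
    # Fill the buffer with the first `buffer_size` elements.
    for _ in range(buffer_size):
        try:
            buffer.append(next(it))
        except StopIteration:
            break
    out = []
    while buffer:
        # Emit a random element from the buffer...
        idx = rng.randrange(len(buffer))
        out.append(buffer.pop(idx))
        # ...and refill the buffer from the remaining stream.
        try:
            buffer.append(next(it))
        except StopIteration:
            pass
    return out

# With buffer_size=1 the output equals the input order.
print(buffered_shuffle(list(range(100)), 1) == list(range(100)))  # True
# With buffer_size=5 the first output can only come from {0..4}.
print(buffered_shuffle(list(range(100)), 5)[0] in range(5))  # True
```

The sketch makes the earlier observation concrete: early outputs can only come from the first few elements, so small buffers produce only local reordering.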
Note: shuffle() does not save the results back to the dataset; each time we iterate over the dataset, the shuffle is executed again. The example below prints the dataset twice and gets two different results.
shuffled = ds.shuffle(buffer_size=5)
printDs(shuffled,10) # print first time
printDs(shuffled,10) # print second time
To keep the shuffled order fixed, we can pass reshuffle_each_iteration=False; in this case the same shuffled order is reused on every iteration, as shown below.
shuffled = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False)
printDs(shuffled,10) # print first time
printDs(shuffled,10) # print second time
Batching
When training or testing, we usually send a group of examples to the model instead of a single value. Batching groups the tensors into batches to be passed to the model. The batch() function takes a size parameter: how many tensors to pass to the model at a time.
batched = ds.batch(10)
printDs(batched)
In some cases, we may end up with a remainder. For example, if we batch the 100 elements with batch size 14, the last batch contains only 2 elements. We can get rid of the remainder using drop_remainder=True.
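The remainder arithmetic can be checked directly in plain Python:

```python
# 100 elements batched into groups of 14: 7 full batches plus a remainder of 2.
total, batch_size = 100, 14
full_batches, remainder = divmod(total, batch_size)
print(full_batches, remainder)  # 7 2
```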
batched = ds.batch(14, drop_remainder=True)
printDs(batched,10)
Shuffle the batched
If we add the shuffle API on top of the batched data as follows:
Shuffle_batched = ds.batch(14, drop_remainder=True).shuffle(buffer_size=5)
printDs(Shuffle_batched,10)
As you can see in the output, the batches are not in order, but the content of each batch is in order. This is because batch() comes first, and the shuffle is then applied at the batch level.
Batch the shuffled
Alternatively, if we shuffle first and then batch, the batches arrive in order but each batch contains its own shuffled elements.
Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True)
printDs(Batch_shuffled,10)
Combining all
To cover both cases, we can shuffle the batches of already-shuffled elements:
shuffle_Batch_shuffled = ds.shuffle(buffer_size=5).batch(14, drop_remainder=True).shuffle(buffer_size=50)
printDs(shuffle_Batch_shuffled,10)
In this case each batch has its own shuffled elements, and the dataset has its own shuffled batches.
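A rough pure-Python analogy of the three pipelines can help contrast them. This is illustrative only: the batches helper is made up, and random.shuffle stands in for TensorFlow's buffered shuffle (it shuffles fully rather than through a buffer):

```python
import random

def batches(data, size):
    """Split data into full batches of `size`, dropping the remainder."""
    usable = len(data) - len(data) % size
    return [data[i:i + size] for i in range(0, usable, size)]

rng = random.Random(0)
data = list(range(100))

# 1) Batch then shuffle: batch order is random, batch contents stay ordered.
batch_then_shuffle = batches(data, 14)
rng.shuffle(batch_then_shuffle)

# 2) Shuffle then batch: batches arrive in stream order, contents are shuffled.
shuffled = data[:]
rng.shuffle(shuffled)
shuffle_then_batch = batches(shuffled, 14)

# 3) Shuffle, batch, then shuffle again: both levels are randomized.
both_levels = batches(shuffled, 14)
rng.shuffle(both_levels)

# Every batch in case 1 is still internally ordered.
print(all(b == sorted(b) for b in batch_then_shuffle))  # True
```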
We can now write a complete pipeline that stops reshuffling each time we iterate over the dataset ds:
myFinalDs = ds.shuffle(buffer_size=5, reshuffle_each_iteration=False).batch(14, drop_remainder=True).shuffle(buffer_size=50, reshuffle_each_iteration=False)
printDs(myFinalDs,10)
printDs(myFinalDs,10)
With this, we have a fixed dataset that has shuffled elements and its own shuffled batches.
Conclusion
The shuffle and batch APIs are very helpful tools when preparing data before feeding it into TensorFlow models. There is no right or wrong approach; each problem has its own requirements. The explanation above shows the flexibility of the batch and shuffle Dataset APIs.
Full Code