Batching Strategies For LSTM Input

Banjoko Judah
Published in Analytics Vidhya
5 min read · Mar 6, 2021

How to do it right in PyTorch

Introduction

A common frustration when working with RNN layers, especially on variable-length data, is batching. It is common knowledge that batching our data brings many benefits, but just thinking about how to do it can break one’s spirit.

The thought of the complex, inefficient code we’ll have to write, or borrow from somewhere else, that only its author understands… the horror! At other times, it seems impossible to get right at all. Well, I say it does not have to be so.

In this article, I will share some of the strategies I use when batching my data. Though I’ll be using PyTorch here, the same ideas work with other libraries. The complete code used in this article resides in this GitHub repository.

Overview

I will introduce the strategies at different phases of our deep-learning process. So that we can focus on the topic, I’ll reuse the task from a previous article I wrote. Reading it is not required, though!

We will train our model to generate dinosaur names; this is the data we will be using. It is a text file that contains one dinosaur name per line. If something I do in this task is unclear, I have probably already covered it in the other article.

With that, let’s get started.

1. Vocabulary

The first strategy has to do with how we define our vocabulary. But before we go into that, we need to preprocess our data. The goal is to convert each dinosaur name to a list of its lowercase characters and append the EOS (End Of Sequence) token to the list. For example, the name Aachenosaurus becomes:

["a", "a", "c", "h", "e", "n", "o", "s", "a", "u", "r", "u", "s", "<EOS>"]
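A minimal sketch of this preprocessing step might look like the following. The helper name preprocess and the literal "<EOS>" string are hypothetical placeholders, not taken from the repository, and the sketch assumes names are lowercased, as in the example list.

```python
from typing import List

EOS = "<EOS>"  # placeholder string for the End Of Sequence token (assumed)


def preprocess(name: str) -> List[str]:
    """Hypothetical helper: lowercase a name, split it into characters, append EOS."""
    return list(name.strip().lower()) + [EOS]


print(preprocess("Aachenosaurus"))
```

To process the whole data file, you could apply it to every non-empty line, e.g. `[preprocess(line) for line in f if line.strip()]` for an open file `f`.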

After that, we create our vocabulary, which contains the 26 letters of the English alphabet, the EOS token, and the PAD (Padding) token, giving us a vocabulary size of 28. Then we create two dictionaries: the first maps each item in our vocabulary to a unique integer, and the second does the reverse.

Now, it is super important that we add the PAD token last when defining our vocabulary. Doing this makes the PAD index (the index of the PAD token) 27, the largest index in the vocabulary. I will explain why we do this later on when we define our model.
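Here is a sketch of the vocabulary and the two dictionaries. The "<PAD>" and "<EOS>" strings and the names char_to_ix, ix_to_char, and PAD_IX are my own assumptions; the point it illustrates is that appending PAD last gives it index 27.

```python
import string

EOS, PAD = "<EOS>", "<PAD>"  # placeholder token strings (assumed)

# 26 letters first, then EOS, and PAD added LAST so it gets the largest index.
vocab = list(string.ascii_lowercase) + [EOS, PAD]

char_to_ix = {ch: ix for ix, ch in enumerate(vocab)}    # item -> integer
ix_to_char = {ix: ch for ch, ix in char_to_ix.items()}  # integer -> item

PAD_IX = char_to_ix[PAD]
print(len(vocab), PAD_IX)  # 28 27
```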

2. Data Loading

The next (and most obvious) phase of batching is getting several samples of equal length and stacking them together. There are many ways of doing this, but we want to keep things straightforward here. What we’ll do is pad each sample with the PAD index until its length equals that of the longest training input (X) in our data.

We start by converting each character in our already processed data to an integer, and at the same time, we keep track of the length of the longest X. As we’ll soon see, X is one token shorter than the full sample, which is why I subtract one when updating max_seqlen. The batched argument lets us skip the padding entirely when we don’t want to batch our data.

Next is the __getitem__ method, where we slice X and Y out of the sample at index ix. Then we append as many PAD indexes as it takes for their lengths to equal max_seqlen, convert them to tensors, and return them. We also record the length of X before padding it, because a PyTorch function we use later needs it.
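The Dataset described above can be sketched roughly as follows. This is not the repository's code: the class name DinoDataset is hypothetical, the vocabulary is rebuilt inline so the block runs on its own, and it assumes X is the sample without its last token while Y is the sample shifted left by one.

```python
import string
from typing import List

import torch
from torch.utils.data import Dataset

EOS, PAD = "<EOS>", "<PAD>"  # placeholder token strings (assumed)
VOCAB = list(string.ascii_lowercase) + [EOS, PAD]
CHAR_TO_IX = {ch: ix for ix, ch in enumerate(VOCAB)}
PAD_IX = CHAR_TO_IX[PAD]  # 27, because PAD is last


class DinoDataset(Dataset):  # hypothetical name
    def __init__(self, names: List[str], batched: bool = True):
        self.batched = batched
        self.samples = []
        self.max_seqlen = 0
        for name in names:
            sample = [CHAR_TO_IX[ch] for ch in name.strip().lower()] + [CHAR_TO_IX[EOS]]
            self.samples.append(sample)
            # X drops the last token, so it is one shorter than the sample.
            self.max_seqlen = max(self.max_seqlen, len(sample) - 1)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, ix):
        sample = self.samples[ix]
        x, y = sample[:-1], sample[1:]  # input and next-token target
        x_len = len(x)                  # true length, recorded BEFORE padding
        if self.batched:
            n_pad = self.max_seqlen - x_len
            x = x + [PAD_IX] * n_pad
            y = y + [PAD_IX] * n_pad
        return torch.tensor(x), torch.tensor(y), x_len
```

With ["Aachenosaurus", "Abelisaurus"], max_seqlen is 13, and the second sample's X ends with two PAD indexes while x_len stays 11.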

We finish by creating a DataLoader object with batch_size set to 16 and shuffle set to True. If looping over this DataLoader raises an error, you are probably doing something different from what is described here.
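As a self-contained sketch of that final step and the sanity-check loop, the block below uses a TensorDataset of random, already-padded token ids as a stand-in for the real Dataset; the sizes (40 samples, max length 13) are arbitrary assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

PAD_IX, MAX_SEQLEN, N_SAMPLES = 27, 13, 40

# Stand-in data: random real token ids (0..26) plus a true length per sample.
lengths = torch.randint(1, MAX_SEQLEN + 1, (N_SAMPLES,))
x = torch.randint(0, 27, (N_SAMPLES, MAX_SEQLEN))
y = torch.randint(0, 27, (N_SAMPLES, MAX_SEQLEN))
positions = torch.arange(MAX_SEQLEN).unsqueeze(0)  # (1, MAX_SEQLEN)
pad_mask = positions >= lengths.unsqueeze(1)       # True where padding goes
x[pad_mask] = PAD_IX
y[pad_mask] = PAD_IX

loader = DataLoader(TensorDataset(x, y, lengths), batch_size=16, shuffle=True)

# Sanity check: every batch stacks rows of the same padded length.
for xb, yb, lb in loader:
    print(xb.shape, yb.shape, lb.shape)
```

With 40 samples and batch_size 16, the loop yields batches of 16, 16, and 8 rows.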

3. Model Definition

In the __init__ method of our model, notice that we set the padding_idx parameter of the Embedding layer to the PAD index. This initializes the embedding at the PAD index to a vector of zeros and ensures that its gradient is always zero. So, unlike the other embeddings, it does not change during backpropagation.
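A quick way to see both effects of padding_idx is the check below; the embedding size of 4 is an arbitrary assumption.

```python
import torch
import torch.nn as nn

PAD_IX = 27
emb = nn.Embedding(num_embeddings=28, embedding_dim=4, padding_idx=PAD_IX)

print(emb.weight[PAD_IX])  # all zeros at initialization

# Backpropagate through a sequence containing real tokens and padding.
tokens = torch.tensor([[3, 5, PAD_IX, PAD_IX]])
emb(tokens).sum().backward()
print(emb.weight.grad[PAD_IX])  # all zeros: the PAD row receives no gradient
print(emb.weight.grad[3])       # non-zero for a real token
```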

Notice also that we set the number of output features of the Linear layer to one less than the vocabulary size, i.e. 27. If you remember, we deliberately made the PAD token last in our vocabulary, so the outputs cover indices 0 to 26 only. As a result, our model can NEVER predict the PAD index as an answer.

We could also achieve this with the PAD index at some other position, such as zero (as is often seen), but then the Linear layer’s output indices would no longer line up with the vocabulary indices, which makes the results very confusing to interpret.
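Putting the model definition together, a minimal sketch could look like this. The class name DinoModel and the layer sizes are hypothetical. Two choices go beyond the text: enforce_sorted=False, because shuffled batches are not sorted by length and pack_padded_sequence rejects unsorted input by default; and total_length, so the output keeps the full padded length and lines up with the padded targets.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

VOCAB_SIZE, PAD_IX = 28, 27


class DinoModel(nn.Module):  # hypothetical name; sizes are assumptions
    def __init__(self, emb_dim: int = 32, hidden_dim: int = 64):
        super().__init__()
        # The PAD row starts as zeros and never receives a gradient.
        self.embedding = nn.Embedding(VOCAB_SIZE, emb_dim, padding_idx=PAD_IX)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True)
        # VOCAB_SIZE - 1 outputs cover indices 0..26, so PAD (27) is never predicted.
        self.fc = nn.Linear(hidden_dim, VOCAB_SIZE - 1)

    def forward(self, x, x_len):
        emb = self.embedding(x)  # (B, T, emb_dim)
        # Lengths must be a CPU tensor; batches are unsorted because of shuffling.
        packed = pack_padded_sequence(
            emb, x_len.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        out, _ = pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1)
        )  # (B, T, hidden_dim)
        return self.fc(out)  # (B, T, VOCAB_SIZE - 1)
```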

4. Forward Pass

Because of the padding we added, the forward pass has to mask the input so that the LSTM layer ignores each sample’s padding. Here, our forward pass is defined entirely in the forward method of our model. It looks much as it would without batching, except that we now call two weirdly named functions that make our lives 100X easier.

The first of these functions, pack_padded_sequence, does the masking for us. It takes our embeddings and the true length of each sample in the batch as input; this is where we use x_len from our Dataset class. It then packs the batch so that only the real (non-padding) time steps remain and returns a PackedSequence object that the LSTM layer understands.
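To see what packing does, here is a toy batch of three sequences with true lengths 3, 2, and 4, padded with zeros to length 4 (the values are made up for illustration). The packed data keeps only the 9 real time steps, and batch_sizes records how many sequences are still active at each time step.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Batch of 3 sequences padded to length 4 (batch_first), feature size 1.
padded = torch.tensor([[1., 2., 3., 0.],
                       [4., 5., 0., 0.],
                       [6., 7., 8., 9.]]).unsqueeze(-1)  # (3, 4, 1)
lengths = torch.tensor([3, 2, 4])

packed = pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
print(packed.data.squeeze(-1))  # the 9 real steps, interleaved time step by time step
print(packed.batch_sizes)       # tensor([3, 3, 2, 1])
print(packed.sorted_indices)    # tensor([2, 0, 1]): longest sequence first
```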

It is important to note that the output of the LSTM layer is also a PackedSequence object. The Linear layer does not accept this object, so we need to convert it back to a tensor first. We do this with the second function, pad_packed_sequence, which does the inverse of the first.
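The round trip through the LSTM can be checked with random inputs, as in the sketch below (the sizes are arbitrary). pad_packed_sequence returns a regular padded tensor plus the lengths, restored to the original batch order, and fills the padded steps with zeros by default.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
lstm = nn.LSTM(input_size=5, hidden_size=8, batch_first=True)

x = torch.randn(3, 6, 5)  # 3 sequences padded to length 6
lengths = torch.tensor([6, 2, 4])

packed_in = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, _ = lstm(packed_in)  # the LSTM returns a PackedSequence too
out, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

print(type(packed_out).__name__)  # PackedSequence
print(out.shape)                  # torch.Size([3, 6, 8]), a regular tensor again
print(out_lengths)                # tensor([6, 2, 4])
```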

5. Loss Function

The final piece of our strategy is that, when calculating the loss, we ignore positions where the target (Y) equals the PAD index. This way, the padding does not contribute to our gradient and so does not affect our model’s training. The same goes for evaluating our model: if we do not ignore the padding, it adds noise that we don’t want.
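The sketch below checks that claim on random logits: with ignore_index set to the PAD index, the padded time steps receive exactly zero gradient. The toy targets are made up, and 27 output classes matches the Linear layer described earlier.

```python
import torch
import torch.nn as nn

PAD_IX, N_CLASSES = 27, 27
torch.manual_seed(0)

logits = torch.randn(2, 4, N_CLASSES, requires_grad=True)  # (batch, time, classes)
targets = torch.tensor([[3, 7, 26, PAD_IX],
                        [5, 26, PAD_IX, PAD_IX]])

criterion = nn.CrossEntropyLoss(ignore_index=PAD_IX)
# CrossEntropyLoss expects classes in dim 1, so flatten (batch, time) into one axis.
loss = criterion(logits.reshape(-1, N_CLASSES), targets.reshape(-1))
loss.backward()

pad_mask = targets == PAD_IX
print(logits.grad[pad_mask].abs().sum())  # zero: padded steps get no gradient
```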

Luckily, PyTorch’s CrossEntropyLoss can do this for us. All we need to do is set its ignore_index parameter to the PAD index, and that’s it. I also included an example of how to do this with a loss function that lacks the ignore_index capability.
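Here is one possible sketch of both routes, not necessarily the repository's version. The manual route computes an unreduced per-step loss, zeroes it at PAD positions, and averages over the real steps only. Because 27 is not a valid class for a 27-output layer, the PAD targets are temporarily swapped for 0 before the unreduced call; those steps are masked out anyway.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

PAD_IX, N_CLASSES = 27, 27
torch.manual_seed(0)

logits = torch.randn(2, 4, N_CLASSES)
targets = torch.tensor([[3, 7, 26, PAD_IX],
                        [5, 26, PAD_IX, PAD_IX]])
mask = (targets != PAD_IX).float()

# 1) Built-in: CrossEntropyLoss skips PAD targets and averages over the rest.
builtin = nn.CrossEntropyLoss(ignore_index=PAD_IX)(
    logits.reshape(-1, N_CLASSES), targets.reshape(-1)
)

# 2) Manual: per-step losses, zeroed at PAD positions, averaged over real steps.
safe_targets = targets.masked_fill(targets == PAD_IX, 0)  # placeholder class
per_step = F.cross_entropy(
    logits.reshape(-1, N_CLASSES), safe_targets.reshape(-1), reduction="none"
).reshape(targets.shape)
manual = (per_step * mask).sum() / mask.sum()

print(torch.allclose(builtin, manual))  # True
```

The manual pattern works with any loss that can return unreduced, per-element values.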

Conclusion

If you followed the simple strategies listed here, I’m sure you’d be screaming with joy right now. I should also mention that I tested a single pass of no batching vs. batching, and both gave the same output and loss. It was as though we had never batched our data, yet we got all the benefits of batching.
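A check of that kind can be sketched as follows. This is not the author's test code: it uses a tiny randomly initialized embedding, LSTM, and Linear layer with arbitrary sizes, and three hypothetical (X, Y) pairs of different lengths. It runs each sequence on its own without padding, then runs them as one padded, packed batch with ignore_index, and compares the results.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
PAD_IX, N_OUT = 27, 27

# Tiny randomly initialized model pieces (sizes are arbitrary assumptions).
emb = nn.Embedding(28, 8, padding_idx=PAD_IX)
lstm = nn.LSTM(8, 16, batch_first=True)
fc = nn.Linear(16, N_OUT)

# Hypothetical (X, Y) pairs of different lengths; 26 plays the role of EOS.
pairs = [([0, 1, 2, 3], [1, 2, 3, 26]), ([4, 5], [5, 26]), ([6, 7, 8], [7, 8, 26])]

# No batching: one sequence at a time, so no padding and no packing.
single_logits = [fc(lstm(emb(torch.tensor([x])))[0])[0] for x, _ in pairs]
all_targets = torch.tensor([t for _, y in pairs for t in y])
single_loss = F.cross_entropy(torch.cat(single_logits), all_targets)

# Batching: pad with PAD_IX, pack before the LSTM, ignore PAD targets in the loss.
T = max(len(x) for x, _ in pairs)
xb = torch.tensor([x + [PAD_IX] * (T - len(x)) for x, _ in pairs])
yb = torch.tensor([y + [PAD_IX] * (T - len(y)) for _, y in pairs])
lengths = torch.tensor([len(x) for x, _ in pairs])
packed = pack_padded_sequence(emb(xb), lengths, batch_first=True, enforce_sorted=False)
out, _ = pad_packed_sequence(lstm(packed)[0], batch_first=True, total_length=T)
batch_logits = fc(out)
batch_loss = F.cross_entropy(
    batch_logits.reshape(-1, N_OUT), yb.reshape(-1), ignore_index=PAD_IX
)

print(single_loss.item(), batch_loss.item())  # equal up to float rounding
```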

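The main ideas we discussed are: add the PAD token last to the vocabulary, pad every sample to the same length, set padding_idx on the Embedding layer, give the Linear layer one output fewer than the vocabulary size, pack the padded batch before the LSTM layer, and ignore PAD targets in the loss. These same ideas apply when working with other libraries or with other sequential data like audio, video, signals, etc.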

You can find the complete code, including the test code, in the GitHub repository linked at the start of this article.

Please remember to clap, share and follow me if you enjoyed this article. Thanks for reading!


I am a writer, Python developer and an AI practitioner. Super interested in AR & VR and can't wait to see the changes they'll bring to our daily lives.