The dangers of reshaping and other fun mistakes I’ve learnt from PyTorch

There’s a huge disconnect between discussing a potential deep learning architecture and its actual implementation, especially when it comes to batch training. A concept outlined in a paper might seem straightforward, but once you sit down to implement it you find that it’s a bit harder than you realized. Then you remember that you need to incorporate batches in your training.

I’ll be using PyTorch here for examples.

Let’s say you’re just starting out with PyTorch, and you’re working on a language model of some sort. You want an encoded representation of an input sequence. You want to capture dependencies from forward and backwards directions of your sequence, so you decide to encode your sentence with a bidirectional RNN of your choosing.
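As a rough sketch of such an encoder (the choice of a GRU, the single layer and all of the sizes below are my own assumptions for illustration, not anything prescribed above):

```python
import torch
import torch.nn as nn

# Illustrative sizes only; pick whatever fits your model.
batch_size, seq_len, input_size, hidden_size = 3, 7, 10, 5

# A single-layer bidirectional GRU that expects (batch, seq_len, features) inputs.
encoder = nn.GRU(
    input_size,
    hidden_size,
    num_layers=1,
    batch_first=True,
    bidirectional=True,
)

# A random stand-in for an embedded input batch.
inputs = torch.randn(batch_size, seq_len, input_size)

# outputs: the hidden states at every time step, for both directions
# hidden:  the final hidden state of each direction
outputs, hidden = encoder(inputs)
```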

Pretty standard encoder so far. So what sizes are the outputs and the hidden state? If we refer to the PyTorch documentation, we see that with batch_first=True, our outputs should be:

(batch, seq_len, num_directions * hidden_size)

And our final hidden state size should be:

(num_layers * num_directions, batch, hidden_size)

For SOME reason, PyTorch decided to use different layouts for the two: the outputs concatenate both directions along the last dimension at every time step, while the final hidden state stacks the directions (one slice per layer and direction) along the first dimension, and batch_first has no effect on it.

No worries, we can pretty simply take care of this by reshaping our hidden state:
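Something along these lines, sketched with a random stand-in for the hidden state (the sizes are illustrative assumptions):

```python
import torch

batch_size, hidden_size, num_directions = 3, 5, 2

# Stand-in for the final hidden state of a single-layer bidirectional RNN:
# (num_layers * num_directions, batch, hidden_size)
hidden = torch.randn(num_directions, batch_size, hidden_size)

# Naive approach: ask for batch_size rows and let view infer the rest.
naive = hidden.view(batch_size, -1)

print(naive.shape)  # torch.Size([3, 10])
```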

Cool! So this gives the right size… But something smells very very off.

Batch interference and the dangers of reshaping

If we dig a little deeper into how tensors are actually stored in PyTorch, we see that PyTorch (like most tensor libraries) stores a multi-dimensional tensor in a single, contiguous block of memory: basically a 1-D array, plus a stride for each dimension that says how many elements to skip to move one step along that dimension. We can see how this works with this example:
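A minimal reconstruction of that example (the variable names are mine, and the values are chosen so that each batch element is easy to spot):

```python
import torch

# Final hidden state of the forward direction: (batch=3, hidden_size=5)
forward = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5])
# Final hidden state of the backward direction: (batch=3, hidden_size=5)
backward = torch.tensor([[0.1] * 5, [0.2] * 5, [0.3] * 5])

# Stack the directions along a new first dimension:
# (num_layers * num_directions, batch, hidden_size) = (2, 3, 5)
hidden = torch.stack([forward, backward])

print(hidden.shape)     # torch.Size([2, 3, 5])
print(hidden.stride())  # (15, 5, 1)

# The underlying 1-D layout: all of the forward rows, then all of the backward rows.
print(hidden.flatten())
```

The strides say that moving one step along the direction dimension skips 15 elements, one step along the batch dimension skips 5, and one step along the hidden dimension skips 1.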

Here, we create a tensor to mimic the return of the final hidden state of a bidirectional RNN (num_layers * num_directions, batch, hidden_size): a stack of the forwards final hidden state and the backwards final hidden state. Each hidden state has a batch size of 3 and a hidden size of 5. Our 1’s and 0.1’s belong to the first batch element, our 2’s and 0.2’s to the second, and our 3’s and 0.3’s to the third.

What we want to do here is end up with a tensor of size (batch, hidden_size * num_directions), which means ‘zipping’ together the matching rows from the two directions. We want to match all the 1’s with the 0.1’s, and so on.

Let’s see what our naive approach gave us:
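Rebuilding the fake hidden state so this snippet stands on its own (again, the names are my own):

```python
import torch

forward = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5])
backward = torch.tensor([[0.1] * 5, [0.2] * 5, [0.3] * 5])
hidden = torch.stack([forward, backward])  # (2, 3, 5)

batch_size = 3
naive = hidden.view(batch_size, -1)  # (3, 10)
print(naive)
# Row 0: five 1's followed by five 2's
# Row 1: five 3's followed by five 0.1's
# Row 2: five 0.2's followed by five 0.3's
```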

all stirred up

This is an incredibly careless use of the view function! We tell our tensor that its 0th dimension should have size batch_size, and to put everything else in the second dimension. PyTorch simply splits our single contiguous array into 3 equal chunks, from beginning to end, so values from different batch elements end up in the same row. This is batch interference!

Instead, what we actually want to do is first to transpose our first and second dimension:
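A sketch of that step on the same fake hidden state (rebuilt here so the snippet runs on its own):

```python
import torch

forward = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5])
backward = torch.tensor([[0.1] * 5, [0.2] * 5, [0.3] * 5])
hidden = torch.stack([forward, backward])  # (directions=2, batch=3, hidden=5)

# Swap the direction and batch dimensions: (batch=3, directions=2, hidden=5)
transposed = hidden.transpose(0, 1)

print(transposed.shape)  # torch.Size([3, 2, 5])
print(transposed[0])     # a row of 1's and a row of 0.1's: the first batch element
```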

This allows us to swap the directions and batch dimensions, and maintains the correct offsets. This does exactly what transposing a 2-D matrix does, except with one extra dimension.

We also need to call contiguous on this new tensor, precisely because of how PyTorch stores tensors. Transposing a tensor doesn’t move the data in the contiguous memory block where the tensor was originally stored; it just changes the strides used to index into that same memory (hence the name view: we’re just looking at this memory from another viewpoint). This makes such operations incredibly fast, but it also means that if we want this new viewpoint laid out as its own block of contiguous memory, which view requires here, we need to call contiguous.
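To make this concrete, here is a small check on the same fake hidden state (a sketch; the variable names are mine):

```python
import torch

forward = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5])
backward = torch.tensor([[0.1] * 5, [0.2] * 5, [0.3] * 5])
hidden = torch.stack([forward, backward])  # (2, 3, 5)

transposed = hidden.transpose(0, 1)  # (3, 2, 5)

# Same memory, different strides.
print(transposed.data_ptr() == hidden.data_ptr())  # True
print(hidden.stride(), transposed.stride())        # (15, 5, 1) (5, 15, 1)
print(transposed.is_contiguous())                  # False

# view cannot merge the last two dimensions of this non-contiguous tensor.
view_failed = False
try:
    transposed.view(3, -1)
except RuntimeError as err:
    view_failed = True
    print("view failed:", err)

# contiguous() copies the data into a fresh block laid out in the new order.
fixed = transposed.contiguous()
print(fixed.is_contiguous())                  # True
print(fixed.data_ptr() == hidden.data_ptr())  # False
```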

So now we just need to flatten our last two dimensions — this seems like something that we’ve seen before:
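Putting the steps together on the fake hidden state (a sketch; the final comparison against torch.cat is just a sanity check I added):

```python
import torch

forward = torch.tensor([[1.0] * 5, [2.0] * 5, [3.0] * 5])
backward = torch.tensor([[0.1] * 5, [0.2] * 5, [0.3] * 5])
hidden = torch.stack([forward, backward])  # (2, 3, 5)

batch_size = 3

# Transpose, make contiguous, then flatten the last two dimensions.
combined = hidden.transpose(0, 1).contiguous().view(batch_size, -1)  # (3, 10)
print(combined)
# Row 0: five 1's followed by five 0.1's
# Row 1: five 2's followed by five 0.2's
# Row 2: five 3's followed by five 0.3's

# Sanity check: same result as concatenating the two directions feature-wise.
print(torch.equal(combined, torch.cat([hidden[0], hidden[1]], dim=1)))  # True
```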

And there we go — we have all our numbers with 1’s in our first batch, our numbers with 2’s in our second batch, and our numbers with 3’s in our third batch.

The lessons here apply to other tensor libraries too! numpy’s reshape does a very similar operation to PyTorch’s view, so the same care is needed there.
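For instance, here is a sketch of the same mistake and the same fix in numpy, using an array built to mirror the fake hidden state (names and values are my own):

```python
import numpy as np

forward = np.array([[1.0] * 5, [2.0] * 5, [3.0] * 5])
backward = np.array([[0.1] * 5, [0.2] * 5, [0.3] * 5])
hidden = np.stack([forward, backward])  # (2, 3, 5)

# Naive reshape: rows mix values from different batch elements.
naive = hidden.reshape(3, -1)
print(naive[0])  # five 1's followed by five 2's

# Swap the direction and batch axes first, then flatten the last two axes.
fixed = hidden.transpose(1, 0, 2).reshape(3, -1)
print(fixed[0])  # five 1's followed by five 0.1's
```

One difference worth knowing: numpy’s reshape will quietly copy the data when the array isn’t contiguous, so no explicit contiguous-style call is needed there.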