Understanding indexing with PyTorch gather

Mateusz Bednarski
Mar 22, 2020

Have you ever tried to use PyTorch's gather function? I did, and it was waaaay too difficult. The function itself is pretty useful, but figuring out how to use it can be a pain.

So, what is the purpose of the gather function? The docs say:

torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor

Gathers values along an axis specified by dim.
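For a 3-D tensor, the docs spell the rule out concretely: each output element takes its position along dim from the index tensor:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2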

So, it gathers values along an axis. But how does it differ from regular indexing? With the [] operator, you select the same index in every place. Consider a 4x10 tensor (4 is the batch size, 10 is the number of features). When you do x[i, :] or x[:, j], you select the same index across every example or every feature.

(Figure: selecting either the 3rd feature in every example, or every feature from the 3rd example.)
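As a small sketch, take a toy 4x10 tensor; the values are chosen so that x[i, j] == 10 * i + j, which matches the outputs shown later in the article:

import torch

x = torch.arange(40).view(4, 10)  # x[i, j] == 10 * i + j

x[:, 3]  # 3rd feature of every example: tensor([ 3, 13, 23, 33])
x[3, :]  # every feature of the 3rd example: values 30 through 39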

But imagine the following situation: you would like to select the 3rd feature from the 0th example, the 7th feature from the 1st, the 4th from the 2nd, and the 1st from the 3rd.

You might think of:

indices = torch.LongTensor([3, 7, 4, 1])
x[:, indices]

But you’ll get:

tensor([[ 3,  7,  4,  1],
        [13, 17, 14, 11],
        [23, 27, 24, 21],
        [33, 37, 34, 31]])

That is not what we wanted: plain indexing selects columns 3, 7, 4, and 1 from every row, giving a 4x4 result instead of one value per row. OK, we need the gather function.

Gather requires three parameters:

  • input — the input tensor
  • dim — the dimension along which to collect values
  • index — a tensor holding the indices of the values to collect
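Putting these together, here is a minimal sketch of the selection we wanted above. One thing to note from the docs: index must have the same number of dimensions as input, so the 1-D indices are reshaped to one index per row (for dim=1, out[i][j] = input[i][index[i][j]]):

x.gather(1, indices.unsqueeze(1))  # one feature index per example
# tensor([[ 3],
#         [17],
#         [24],
#         [31]])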

