Understanding indexing with PyTorch gather
Have you ever tried to use PyTorch's gather function? I did, and it was waaaay too difficult. The function itself is pretty useful, but figuring out how to use it can be a pain.
So, what is the purpose of the gather function? The docs say:
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
So, it gathers values along an axis. But how does it differ from regular indexing? When using the []
operator, you select the same index in every place. Consider a 4x10 tensor (4 is the batch size, 10 is the number of features). When you do x[_, :]
or x[:, _]
you select the same index in every batch/feature.
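To make this concrete, here is a quick sketch; the construction of x (a 4x10 tensor whose entry (i, j) equals 10*i + j) is my own assumption for illustration:

```python
import torch

# a 4x10 tensor where entry (i, j) equals 10*i + j, for illustration
x = torch.arange(40).reshape(4, 10)

# x[1, :] fixes the same row index for every feature
print(x[1, :])  # tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])

# x[:, 3] fixes the same feature index in every batch element
print(x[:, 3])  # tensor([ 3, 13, 23, 33])
```

Either way, one index is shared across the whole other dimension.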
But imagine the following situation: you'd like to select the 3rd feature from the 0th example, the 7th feature from the 1st example, the 4th from the 2nd and the 1st from the 3rd.
You might think of:
indices = torch.LongTensor([3,7,4,1])
x[:, indices]
But you’ll get:
tensor([[ 3,  7,  4,  1],
        [13, 17, 14, 11],
        [23, 27, 24, 21],
        [33, 37, 34, 31]])
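To see the problem end to end, here is a runnable sketch; the construction of x is my assumption, chosen so its values match the printout above:

```python
import torch

# assumed construction matching the output above: entry (i, j) = 10*i + j
x = torch.arange(40).reshape(4, 10)

indices = torch.LongTensor([3, 7, 4, 1])

# plain [] indexing applies ALL four indices to EVERY row,
# so each row yields four values instead of one
print(x[:, indices])
# tensor([[ 3,  7,  4,  1],
#         [13, 17, 14, 11],
#         [23, 27, 24, 21],
#         [33, 37, 34, 31]])
```

What we actually want is a single value per row (3, 17, 24 and 31), one index per example, not the full cross product of rows and indices.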
Ok, we need the gather function.
Gather requires three parameters:
- input — input tensor
- dim — the dimension along which to collect values