Understanding indexing with pytorch gather
Have you ever tried to use the PyTorch gather function? I did, and it was waaaay too difficult. The function itself is pretty useful, but figuring out how to use it can be a pain.
So, what is the purpose of the gather function? The docs say:
torch.gather(input, dim, index, out=None, sparse_grad=False) → Tensor
Gathers values along an axis specified by dim.
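For a 2-D tensor, the PyTorch docs spell the rule out as out[i][j] = input[index[i][j]][j] for dim=0 and out[i][j] = input[i][index[i][j]] for dim=1. The sketch below tries both on a small 3x3 tensor (the values are made up for illustration) and checks dim=1 against a plain-loop version. The helper `gather_dim1_manual` is hypothetical and exists only to make the indexing rule visible; it is not part of PyTorch.

```python
import torch

# A small 3x3 source tensor (illustrative values)
t = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# dim=1: out[i][j] = t[i][index[i][j]], so each row picks its own columns
idx_rows = torch.tensor([[0, 0],
                         [2, 1],
                         [1, 2]])
out_dim1 = torch.gather(t, 1, idx_rows)   # [[1, 1], [6, 5], [8, 9]]

# dim=0: out[i][j] = t[index[i][j]][j], so each column picks its own row
idx_cols = torch.tensor([[2, 0, 1]])
out_dim0 = torch.gather(t, 0, idx_cols)   # [[7, 2, 6]]


def gather_dim1_manual(src, index):
    """Hypothetical helper: gather along dim=1 for 2-D tensors, written as plain loops."""
    out = torch.empty(index.shape, dtype=src.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            # Row stays i; only the column comes from the index tensor
            out[i, j] = src[i, index[i, j].item()]
    return out


# The loop version should match torch.gather element for element
manual = gather_dim1_manual(t, idx_rows)
```

Note that the output always has the shape of the index tensor, not of the input: a (3, 2) index gives a (3, 2) result, and a (1, 3) index gives a (1, 3) result.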
So, it gathers values along an axis. But how does it differ from regular indexing? When using the [] operator, you select the same index in every place…
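A minimal sketch of that difference, assuming a 3x3 tensor and a made-up list of per-row column choices: slicing with [] takes column 1 from every row, while gather lets each row choose a different column. The last line shows that, for this particular 2-D case, advanced indexing with paired row and column indices gives the same result.

```python
import torch

t = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Regular indexing: the same column (1) is taken from every row
same_col = t[:, 1]                                   # [2, 5, 8]

# gather: row 0 takes column 2, row 1 takes column 0, row 2 takes column 1.
# index must have as many dimensions as t, hence unsqueeze(1) / squeeze(1).
cols = torch.tensor([2, 0, 1])
per_row = torch.gather(t, 1, cols.unsqueeze(1)).squeeze(1)   # [3, 4, 8]

# Equivalent here: advanced indexing with explicit (row, column) pairs
per_row_adv = t[torch.arange(t.size(0)), cols]               # [3, 4, 8]
```

The unsqueeze/squeeze pair is needed because gather requires the index tensor to have the same number of dimensions as the input.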