Understanding the torch.gather function in PyTorch

Pranav Chaturvedi
Published in Analytics Vidhya
3 min read · Oct 18, 2020

Two arguments of this function, index and dim, are the key to understanding it.

For a 2D input, dim = 0 corresponds to rows and dim = 1 corresponds to columns.

For a 3D input, dim = 0 corresponds to the image within a batch, dim = 1 corresponds to rows and dim = 2 corresponds to columns.

Case of 2D input tensor

1. Understanding the dim argument:

a. When dim = 0, we choose rows.

b. When dim = 1, we choose columns.

2. Understanding the index argument:

a. The index tensor must have the same number of dimensions as the input (this does not mean the shapes must be the same).

b. The output tensor will have the same shape as the index tensor.

c. For dim = 0 (2D case), each element of the index tensor tells which row to choose, and the position of that element within the index tensor tells which column to choose.

d. For dim = 1 (2D case), each element of the index tensor tells which column to choose, and the position of that element within the index tensor tells which row to choose.
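Rules c and d can be spelled out as plain Python loops. The sketch below only restates the indexing rule on nested lists. `gather_2d` is a hypothetical helper name, not part of PyTorch, and the 3 x 3 `src` values are made up for illustration.

```python
def gather_2d(src, dim, index):
    """Illustrative restatement of the 2D gather rule on nested lists.

    gather_2d is a hypothetical helper, not a PyTorch function.
    """
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                # The element k picks the row; its position j fixes the column.
                out_row.append(src[k][j])
            elif dim == 1:
                # The element k picks the column; its position i fixes the row.
                out_row.append(src[i][k])
            else:
                raise ValueError("dim must be 0 or 1 for a 2D input")
        out.append(out_row)
    return out


# Assumed example values, chosen only for illustration.
src = [
    [10, 11, 12],
    [20, 21, 22],
    [30, 31, 32],
]

rows_picked = gather_2d(src, 0, [[2, 0, 1]])      # index shape (1, 3)
cols_picked = gather_2d(src, 1, [[2], [0], [1]])  # index shape (3, 1)

print(rows_picked)  # [[30, 11, 22]]
print(cols_picked)  # [[12], [20], [31]]
```

In both calls the result has the shape of the index argument, as point b states.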

Case of 3D input tensor

1. Understanding the dim argument:

a. When dim = 0, we choose the image from the batch.

b. When dim = 1, we choose rows.

c. When dim = 2, we choose columns.

2. Understanding the index argument:

a. Points a and b from the 2D case above apply here as well.

b. For dim = 0 (3D case), each element of the index tensor tells which image from the batch to choose, and the position of that element within the index tensor tells which row and column to choose. The same pattern holds for dim = 1 and dim = 2: the element picks the row or the column respectively, and its position fixes the other two coordinates.
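A minimal sketch of the 3D rule for dim = 0 follows. The tensor `src_3d` and the index values are assumptions made up for this illustration (a batch of 2 "images", each 3 rows by 4 columns). The explicit loops are there only to show the rule and are not how PyTorch implements it.

```python
import torch

# Assumed example: a batch of 2 images, each with 3 rows and 4 columns.
# Image 0 holds 0..11 and image 1 holds 12..23.
src_3d = torch.arange(24).reshape(2, 3, 4)

# dim = 0: each element of the index picks the image;
# its (row, column) position in the index is reused unchanged.
ind_3d = torch.tensor([[[1, 0, 1, 0],
                        [0, 0, 1, 1],
                        [1, 1, 0, 0]]])  # shape (1, 3, 4)

out = torch.gather(src_3d, 0, ind_3d)

# The same rule written as explicit loops.
expected = torch.empty_like(ind_3d)
for i in range(ind_3d.shape[0]):
    for j in range(ind_3d.shape[1]):
        for k in range(ind_3d.shape[2]):
            expected[i, j, k] = src_3d[ind_3d[i, j, k], j, k]

print(out.shape == ind_3d.shape)   # True: output has the index shape
print(torch.equal(out, expected))  # True: loops and gather agree
print(out.tolist())
# [[[12, 1, 14, 3], [4, 5, 18, 19], [20, 21, 10, 11]]]
```

For dim = 1 or dim = 2 the index element would replace `j` or `k` instead of the first coordinate.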

Let’s take two examples for the 2D case.

Example for the case of 2D input tensor

1st Example

Suppose dim = 0 and

ind_2d = [[3, 2, 0, 1]]

ind_2d has shape (1, 4), so the output will have the same shape.

The 0th element of ind_2d, i.e. 3, tells us to choose row 3 (counting from 0; a row because dim = 0) and column 0 (since 3 sits at position 0 of the index tensor).

The 1st element of ind_2d, i.e. 2, tells us to choose row 2 and column 1 (since 2 sits at position 1 of the index tensor). And so on.
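This example can be sketched in code as follows. The contents of `src_2d` are an assumption: a 4 x 4 tensor holding 0 to 15 row by row, which is consistent with the values discussed in the second example.

```python
import torch

# Assumption: src_2d is a 4 x 4 tensor holding 0..15 row by row.
src_2d = torch.arange(16).reshape(4, 4)

ind_2d = torch.tensor([[3, 2, 0, 1]])  # shape (1, 4)

# dim = 0: each index element picks a row, its position picks the column.
out = torch.gather(src_2d, 0, ind_2d)

print(out.shape)     # torch.Size([1, 4]), same as ind_2d
print(out.tolist())  # [[12, 9, 2, 7]]
```

Under that assumption, 12 is the element at row 3, column 0, and 9 is the element at row 2, column 1, matching the reading above.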

2nd Example

Let’s suppose that from the src_2d tensor above we want to select 0, 6, 10 and 15. (Tip: we read these numbers in order from top to bottom, so we form a column-like index tensor. When we read from left to right, we form a row-like index tensor.)

Now 0 belongs to the 0th column, 6 belongs to the 2nd column, 10 belongs to the 2nd column and 15 belongs to the 3rd column.

So our index tensor is [[0], [2], [2], [3]], i.e. the values 0, 2, 2, 3 in column form with shape (4, 1). And since we have selected columns, dim = 1.

There is no way we could have selected rows (dim = 0) and got the desired output tensor: with dim = 0 the column of each chosen element is fixed by its position in the index tensor, and no four-element index tensor has two positions in column 2, which is where both 6 and 10 sit.
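A minimal sketch of this second example, again assuming `src_2d` is a 4 x 4 tensor holding 0 to 15 row by row (the original tensor is not reproduced in the text). The index is written as a column of shape (4, 1).

```python
import torch

# Assumption: src_2d is a 4 x 4 tensor holding 0..15 row by row.
src_2d = torch.arange(16).reshape(4, 4)

# Column-form index: one chosen column per row of src_2d.
ind_col = torch.tensor([[0], [2], [2], [3]])  # shape (4, 1)

# dim = 1: each index element picks a column, its position picks the row.
out = torch.gather(src_2d, 1, ind_col)

print(out.shape)                # torch.Size([4, 1]), same as ind_col
print(out.tolist())             # [[0], [6], [10], [15]]
print(out.squeeze(1).tolist())  # [0, 6, 10, 15] as a flat list
```

The `squeeze(1)` call is optional; it only drops the size-1 column dimension for display.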

The 3D case is very similar. Get the idea from the infographic below.

Example for the case of 3D input tensor

Here is the link to a Python notebook with several examples for both the 2D and 3D input tensor cases:

Link to the PyTorch documentation for the function:

https://pytorch.org/docs/stable/generated/torch.gather.html?highlight=torch%20gather#torch.gather
