
What does the gather function do in pytorch in layman terms?

Tags:

pytorch

I have been through the official docs and this, but it is hard to understand what is going on.

I am trying to understand some DQN source code, and it uses the gather function on line 197.

Could someone explain in simple terms what the gather function does? What is the purpose of that function?

asked Jun 23 '18 09:06 by amitection


People also ask

What does .view do in PyTorch?

PyTorch allows a tensor to be a view of an existing tensor. A view tensor shares the same underlying data with its base tensor. Supporting views avoids an explicit data copy, which allows fast and memory-efficient reshaping, slicing, and element-wise operations.
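A minimal sketch of the shared-data behaviour described above: writing through the view changes the base tensor, because both refer to the same memory. The variable names are illustrative only.

```python
import torch

base = torch.arange(6)        # tensor([0, 1, 2, 3, 4, 5])
v = base.view(2, 3)           # reshaped view, no data is copied

v[0, 0] = 100                 # write through the view...
print(base[0])                # ...and the base tensor sees it: tensor(100)

# Both tensors start at the same memory address.
print(v.data_ptr() == base.data_ptr())  # True
```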

What is the function of PyTorch?

PyTorch: Defining new autograd functions. The forward function computes output Tensors from input Tensors. The backward function receives the gradient of the output Tensors with respect to some scalar value, and computes the gradient of the input Tensors with respect to that same scalar value.
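As a hedged illustration of the forward/backward split, here is a toy custom autograd function (the name `Square` is hypothetical) built on `torch.autograd.Function`. Forward computes x², and backward turns the incoming gradient into the gradient with respect to x, which is 2x times that incoming gradient.

```python
import torch

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Save the input so backward can use it.
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        # d(x^2)/dx = 2x, chained with the gradient flowing in from above.
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
loss = Square.apply(x).sum()  # scalar value
loss.backward()
print(x.grad)  # tensor([2., 4., 6.])
```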

What is gather in PyTorch?

In deep learning we often need to extract values from specified columns of a matrix, and for that we can use the PyTorch gather() function. In other words, gather creates a new tensor by taking, from each row of the input tensor, the values at the given indices along the specified dimension.
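In the DQN setting from the question, a common pattern (assumed here; this is not a claim about what line 197 does) is to pick, for each row of a batch of Q-values, the entry for the action that was taken. `q_values` and `actions` are hypothetical names with made-up values.

```python
import torch

# Hypothetical Q-values: one row per state in the batch, one column per action.
q_values = torch.tensor([[0.1, 0.7, 0.2],
                         [0.5, 0.3, 0.9]])
# Hypothetical action taken in each state (integer, int64 tensor).
actions = torch.tensor([1, 2])

# index must have the same number of dimensions as q_values,
# so reshape actions from (2,) to (2, 1) before gathering along dim=1.
chosen = q_values.gather(1, actions.unsqueeze(1))
print(chosen.shape)       # torch.Size([2, 1])
print(chosen.squeeze(1))  # the values 0.7 and 0.9
```

Each output row holds the single column selected by `actions` for that row.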

What does * do in PyTorch?

For .view(), PyTorch expects the new shape to be provided as individual int arguments (represented in the docs as *shape). The asterisk (*) can be used in Python to unpack a list into its individual elements, thus passing view() the input arguments in the form it expects.
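A minimal example of the unpacking; `shape` is an illustrative list.

```python
import torch

x = torch.arange(6)
shape = [2, 3]

# *shape unpacks the list, so this call is the same as x.view(2, 3).
y = x.view(*shape)
print(y.shape)  # torch.Size([2, 3])
```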


1 Answer

torch.gather creates a new tensor from the input tensor by taking values from each row along the input dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each 'row'. The output tensor has the same shape as the index tensor. The following illustration from the official docs explains it more clearly:

[Figure: pictorial representation of torch.gather, from the docs]

(Note: In the illustration, indexing starts from 1 and not 0).
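For 2-D tensors the rule is out[i][j] = src[index[i][j]][j] when dim=0 and out[i][j] = src[i][index[i][j]] when dim=1, with 0-based indices (unlike the illustration). The sketch below spells that rule out as a plain loop and checks it against torch.gather; `gather_2d` is a hypothetical helper for illustration only, not part of PyTorch.

```python
import torch

def gather_2d(src, dim, index):
    """Plain-loop version of torch.gather for 2-D tensors (illustration only)."""
    rows, cols = index.shape
    out = torch.empty(rows, cols, dtype=src.dtype)  # same shape as index
    for i in range(rows):
        for j in range(cols):
            k = int(index[i, j])
            if dim == 0:
                out[i, j] = src[k, j]  # index value replaces the row coordinate
            else:
                out[i, j] = src[i, k]  # index value replaces the column coordinate
    return out

src = torch.arange(1, 10).reshape(3, 3)
index = torch.tensor([[2, 0, 1],
                      [0, 0, 2]])

for d in (0, 1):
    assert torch.equal(gather_2d(src, d, index), torch.gather(src, d, index))
```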

In the first example, the dimension given is along the rows (top to bottom), so for position (1,1) of the result, it takes the row value from the index, which is 1. The value at (1,1) in src is 1, so it outputs 1 at (1,1) in the result. Similarly, for (2,2) the row value from the index is 3. The value at (3,2) in src is 8, so it outputs 8, and so on.
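The dim=0 case can be reproduced in code under an assumption: src is taken to be the 3×3 matrix holding 1 to 9 row by row, which is consistent with the values quoted above, but the index tensor below is made up (the illustration's own index is not reproduced). Indices are 0-based, so position (1,1) in the text is out[0, 0] and (2,2) is out[1, 1].

```python
import torch

# Assumed source matrix (consistent with the values quoted in the answer).
src = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
# Hypothetical index: each entry picks a ROW of src for that output column.
index = torch.tensor([[0, 1, 2],
                      [1, 2, 0]])

out = torch.gather(src, 0, index)
print(out)
# tensor([[1, 5, 9],
#         [4, 8, 3]])
```

Here out[0, 0] reads src[0, 0] = 1 and out[1, 1] reads src[2, 1] = 8, matching the two positions walked through above.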

Similarly, in the second example indexing is along the columns, so for position (2,2) of the result, the column value from the index is 3. The value at (2,3) in src is 6, so 6 is taken and output to the result at (2,2).
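The dim=1 case, again as a sketch with the same assumed 3×3 src and a made-up index (not the one from the illustration). In 0-based terms, position (2,2) is out[1, 1], and an index value of 2 there reads src[1, 2] = 6.

```python
import torch

# Assumed source matrix (consistent with the values quoted in the answer).
src = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])
# Hypothetical index: each entry picks a COLUMN of src within the same row.
index = torch.tensor([[0, 1],
                      [1, 2],
                      [2, 0]])

out = torch.gather(src, 1, index)
print(out)
# tensor([[1, 2],
#         [5, 6],
#         [9, 7]])
```

The result has the shape of index (3×2), not of src.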

answered Sep 28 '22 00:09 by Ritesh