Is column selection in PyTorch differentiable?

Is column selection in PyTorch differentiable? For example, if I want to select a single column from each row to make a new N x 1 array and then backprop using this new array, will the backprop work properly?

qvalues = qvalues[range(5),[0,1,0,1,0]]

Will it work if element selection is done as shown above on a 5x2 tensor?

asked Jul 16 '18 by patrick

1 Answer

I think it is. Let me make an example with code.

First we create the qvalues tensor and tell PyTorch that we want to compute gradients for it:

qvalues = torch.rand((5, 5), requires_grad=True)

Now we create a tensor to index it, obtaining a 5x2 tensor as the result (I think this is the same kind of selection you wanted to perform with qvalues[range(5),[0,1,0,1,0]]):

y = torch.LongTensor([1, 3])
new_qvalues = qvalues[:, y]

We can check that the slice new_qvalues of the original qvalues will have its gradient computed:

print(new_qvalues.requires_grad) # True

Now we perform our mathematical operations. In this example, I square new_qvalues because we know that its gradient (derivative) will be 2 * new_qvalues.

qvalues_a = new_qvalues ** 2

Now we compute the gradients of qvalues_a. We pass a tensor of ones as the gradient argument because qvalues_a is not a scalar, and set retain_graph=True so the computation graph's buffers are not freed and backward can be called again later.

qvalues_a.backward(torch.ones(new_qvalues.shape), retain_graph=True)
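As an aside (not part of the original answer), an equivalent and common pattern is to reduce the result to a scalar first, which lets you call backward() with no arguments; a minimal sketch:

```python
import torch

qvalues = torch.rand((5, 5), requires_grad=True)
y = torch.LongTensor([1, 3])
new_qvalues = qvalues[:, y]

# Summing to a scalar lets us call backward() without an explicit
# gradient argument; qvalues.grad ends up the same as when passing
# torch.ones(new_qvalues.shape) to backward().
loss = (new_qvalues ** 2).sum()
loss.backward()

# The gradient of x**2 is 2*x, and it lands only in the selected columns.
assert torch.equal(qvalues.grad[:, y], 2 * new_qvalues)
assert torch.all(qvalues.grad[:, [0, 2, 4]] == 0)
```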

Now we can go back to the original qvalues and see whether the gradients have been calculated:

print(qvalues)
print(qvalues.grad)

# result of the print statements

#tensor([[ 0.9677,  0.4303,  0.2036,  0.3870,  0.6085],
#    [ 0.8876,  0.8695,  0.2028,  0.3283,  0.1560],
#    [ 0.1764,  0.4718,  0.5418,  0.5167,  0.6200],
#    [ 0.7610,  0.9322,  0.5584,  0.5589,  0.8901],
#    [ 0.8146,  0.7296,  0.8036,  0.5277,  0.5754]])

#tensor([[ 0.0000,  0.8606,  0.0000,  0.7739,  0.0000],
#    [ 0.0000,  1.7390,  0.0000,  0.6567,  0.0000],
#    [ 0.0000,  0.9435,  0.0000,  1.0334,  0.0000],
#    [ 0.0000,  1.8645,  0.0000,  1.1178,  0.0000],
#    [ 0.0000,  1.4592,  0.0000,  1.0554,  0.0000]])

We can observe that the gradients have been computed only at the selected indices. To be sure about it, we run a quick check comparing the value of qvalues.grad for the selected slice against the derivative 2 * new_qvalues:

assert torch.equal(qvalues.grad[:, y], 2 * new_qvalues)

It does not throw any error, so I would conclude that you can get the gradient of the slice.
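The same check can be applied to the exact indexing from the question, which picks one element per row of a 5x2 tensor. A minimal sketch (not from the original answer):

```python
import torch

qvalues = torch.rand((5, 2), requires_grad=True)
cols = [0, 1, 0, 1, 0]  # one column index per row, as in the question

# Advanced indexing picks one element per row -> a 5-element tensor.
selected = qvalues[range(5), cols]
assert selected.requires_grad

# Backprop through a simple scalar function of the selection.
(selected ** 2).sum().backward()

# Each row's gradient lands only in the chosen column (2 * the value
# there); the other column stays zero.
for i, c in enumerate(cols):
    assert torch.isclose(qvalues.grad[i, c], 2 * qvalues[i, c])
    assert qvalues.grad[i, 1 - c] == 0
```

So the row-wise selection in the question is differentiable in exactly the same way as the column slice above.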

answered Sep 16 '22 by Manuel Lagunas