Is column selection in PyTorch differentiable? For example, if I select a single column from each row to make a new (rows x 1) array and then backprop using this new array, will backprop work properly?
qvalues = qvalues[range(5),[0,1,0,1,0]]
Here the elements are selected from a 5x2 tensor, as shown above.
I think it is. Let me show it with a code example.
First we create the qvalues tensor and set requires_grad=True so that autograd computes its gradients:
import torch
qvalues = torch.rand((5, 5), requires_grad=True)
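Now we create an index tensor and use it to select columns 1 and 3 of every row, which gives a 5x2 tensor. Note that this is not exactly the selection in your qvalues[range(5),[0,1,0,1,0]], which picks one element per row and returns a tensor of shape (5,), but both are index selections that autograd tracks: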
y = torch.tensor([1, 3])  # column indices (int64 by default)
new_qvalues = qvalues[:, y]
We can check that new_qvalues, the slice of the original qvalues, is tracked by autograd:
print(new_qvalues.requires_grad) # True
Now we perform our mathematical operations. In this example I square new_qvalues, because we know that the gradient (derivative) of the result with respect to new_qvalues is 2 * new_qvalues.
qvalues_a = new_qvalues ** 2
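Now we compute the gradients of qvalues_a. Because qvalues_a is not a scalar, backward needs a gradient argument of the same shape, and here we pass a tensor of ones. Setting retain_graph=True keeps the graph buffers from being freed, so backward could be called on the same graph again; it is not required for a single backward pass, and the .grad of a leaf tensor such as qvalues is stored either way.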
qvalues_a.backward(torch.ones(new_qvalues.shape), retain_graph=True)
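The backward call above passes a tensor of ones as the gradient argument. As a minimal sketch (the names x and cols are hypothetical stand-ins for qvalues and y), the following suggests that this gives the same gradients as summing to a scalar first, and that retain_graph=True is not needed when backward is called only once per graph:

```python
import torch

x = torch.rand((5, 5), requires_grad=True)
cols = torch.tensor([1, 3])

# Variant 1: non-scalar output with an explicit gradient of ones
(x[:, cols] ** 2).backward(torch.ones(5, 2))
grad_explicit = x.grad.clone()

# Reset the accumulated gradient before the second pass
x.grad = None

# Variant 2: reduce to a scalar first; no gradient argument, no retain_graph
(x[:, cols] ** 2).sum().backward()
grad_sum = x.grad.clone()

print(torch.allclose(grad_explicit, grad_sum))
```

Each variant builds its own graph from x, so neither backward call depends on buffers kept from the other.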
Now we can go back to the original qvalues and see whether the gradients have been calculated:
print(qvalues)
print(qvalues.grad)
# result of the print statements
#tensor([[ 0.9677, 0.4303, 0.2036, 0.3870, 0.6085],
# [ 0.8876, 0.8695, 0.2028, 0.3283, 0.1560],
# [ 0.1764, 0.4718, 0.5418, 0.5167, 0.6200],
# [ 0.7610, 0.9322, 0.5584, 0.5589, 0.8901],
# [ 0.8146, 0.7296, 0.8036, 0.5277, 0.5754]])
#tensor([[ 0.0000, 0.8606, 0.0000, 0.7739, 0.0000],
# [ 0.0000, 1.7390, 0.0000, 0.6567, 0.0000],
# [ 0.0000, 0.9435, 0.0000, 1.0334, 0.0000],
# [ 0.0000, 1.8645, 0.0000, 1.1178, 0.0000],
# [ 0.0000, 1.4592, 0.0000, 1.0554, 0.0000]])
We can observe that the gradients are nonzero only at the selected indices (columns 1 and 3) and zero everywhere else. To be sure about it, we add a quick test that the value of qvalues.grad for the selected slice is equal to the derivative 2 * new_qvalues.
assert torch.equal(qvalues.grad[:, y], 2 * new_qvalues)
The assertion does not raise any error, so I would conclude that you can get the gradient through the slice.
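For completeness, the same check can be applied to the per-row selection from the question. This is only a sketch under a few assumptions: the tensor is 5x2 as in the question, the name actions is hypothetical, and torch.arange(5) stands in for range(5). It also shows torch.gather, which performs the same selection but keeps a (5, 1) shape:

```python
import torch

# 5x2 tensor of Q-values, as in the question
qvalues = torch.rand((5, 2), requires_grad=True)
actions = torch.tensor([0, 1, 0, 1, 0])  # hypothetical name: one column index per row

# Advanced indexing: picks qvalues[i, actions[i]] for each row i -> shape (5,)
selected = qvalues[torch.arange(5), actions]

# Same selection with torch.gather, which keeps a (5, 1) shape
gathered = qvalues.gather(1, actions.unsqueeze(1))

# Reduce to a scalar and backpropagate through the indexing
loss = (selected ** 2).sum()
loss.backward()

# Expected gradient: 2 * q at the selected positions, zero elsewhere
expected = torch.zeros_like(qvalues)
expected[torch.arange(5), actions] = 2 * selected.detach()
print(torch.allclose(qvalues.grad, expected))
```

Only the chosen entry in each row receives a gradient; the entry that was not selected gets zero, just as in the column-slice example.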