 

How to apply the torch.inverse() function of PyTorch to every sample in the batch?

Tags:

python

pytorch

This may seem like a basic question, but I am unable to work it through.

In the forward pass of my neural network, I have an output tensor of shape 8x3x3, where 8 is my batch size. We can assume each 3x3 matrix is non-singular. I need to find the inverse of these matrices. The PyTorch inverse() function only works on a single square (2-D) matrix. Since my tensor is 8x3x3, how do I apply this function to every matrix in the batch in a differentiable manner?

If I iterate through the samples, append the inverses to a Python list, and then convert that list to a PyTorch tensor, will that be a problem during backprop? (I am asking because if you convert PyTorch tensors to NumPy to perform some operations and then convert back to a tensor, no gradients are computed for those operations during backprop.)

I also get the following error when I try to do something like that.

a = torch.arange(0,8).view(-1,2,2)
b = [m.inverse() for m in a]
c = torch.FloatTensor(b)

TypeError: 'torch.FloatTensor' object does not support indexing
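As far as I can tell, the Python loop itself is not the problem for autograd; the issue is how the list is turned back into a tensor. Building a fresh tensor from a list (torch.FloatTensor(b), or torch.tensor(b) in newer versions) creates a new leaf with no autograd history, whereas torch.stack records the operation. A minimal sketch, assuming a recent PyTorch where torch.arange returns integers by default, so the float dtype is requested explicitly (torch.inverse needs floating-point input):

```python
import torch

# Leaf tensor that requires gradients. torch.inverse needs a floating-point
# dtype, so the values 0..7 are created as float64 here.
x = torch.arange(8, dtype=torch.float64, requires_grad=True)
a = x.view(-1, 2, 2)              # two 2x2 matrices, both invertible

# Invert each matrix separately, then join the results with torch.stack.
# torch.stack is tracked by autograd, so gradients flow back to x
# (unlike constructing a new tensor from the Python list).
b = [m.inverse() for m in a]
c = torch.stack(b)                # shape: (2, 2, 2)

c.sum().backward()
print(c[0])
print(x.grad.shape)
```

This works, but it launches one small inverse per sample; the batched approaches in the answer avoid the Python loop.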

asked Oct 05 '17 21:10 by phoenixwing




1 Answer

EDIT:

As of PyTorch 1.0, torch.inverse supports batches of matrices: every leading dimension is treated as a batch dimension. So you can simply use the built-in function torch.inverse.
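A minimal sketch of the batched call on the 8x3x3 shape from the question. The random matrices shifted by 3I are an assumption standing in for the network output, chosen so they are comfortably non-singular; float64 is used only to keep the numerical check tight. Newer releases also provide torch.linalg.inv, which accepts the same batched input.

```python
import torch

torch.manual_seed(0)
# Stand-in for the network output: 8 random 3x3 matrices, shifted by 3*I
# so that they are well-conditioned (an assumption for this demo).
a = torch.randn(8, 3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
a.requires_grad_(True)

# torch.inverse inverts each 3x3 matrix in the batch independently.
a_inv = torch.inverse(a)          # shape: (8, 3, 3)

# Sanity check: A @ A^-1 should be the identity for every sample.
eye = torch.eye(3, dtype=torch.float64).expand(8, 3, 3)
print(torch.allclose(a @ a_inv, eye))

# The batched inverse is differentiable, so backprop works as usual.
a_inv.sum().backward()
print(a.grad.shape)
```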

OLD ANSWER

There are plans to implement a batched inverse soon; for discussion, see for example issue 7500 or issue 9102. However, as of the time of writing, the current stable version (0.4.1) has no batched inverse operation.

That said, batch support was recently added to torch.gesv, which solves the linear system AX = B for X. Passing the identity matrix as B gives X = A^-1, so gesv can be (ab)used to define your own batched inverse operation along the following lines:

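import torch

def b_inv(b_mat):
    # Identity matrix broadcast to the batch shape (..., n, n)
    eye = b_mat.new_ones(b_mat.size(-1)).diag().expand_as(b_mat)
    # torch.gesv(B, A) solves A X = B (PyTorch 0.4.x API); with B = I, X = A^-1
    inv, _ = torch.gesv(eye, b_mat)
    return inv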

I found that this gives good speed-ups over a for loop when running on GPU.
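torch.gesv is not available in current PyTorch releases. A minimal sketch of the same identity-solve trick with torch.linalg.solve, assuming a recent PyTorch; note that torch.linalg.solve(A, B) takes its arguments in the opposite order from gesv(B, A). The helper name b_inv_solve is hypothetical, and in practice the built-in batched torch.inverse (or torch.linalg.inv) is the simpler choice.

```python
import torch

def b_inv_solve(b_mat: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: batched inverse by solving A X = I for each A."""
    n = b_mat.size(-1)
    # Identity broadcast to the batch shape (..., n, n)
    eye = torch.eye(n, dtype=b_mat.dtype, device=b_mat.device).expand_as(b_mat)
    # torch.linalg.solve(A, B) solves A X = B, i.e. (A, B) order, unlike gesv(B, A)
    return torch.linalg.solve(b_mat, eye)

torch.manual_seed(0)
# Well-conditioned stand-in batch (an assumption for this demo)
a = torch.randn(8, 3, 3, dtype=torch.float64) + 3 * torch.eye(3, dtype=torch.float64)
print(torch.allclose(b_inv_solve(a), torch.inverse(a)))
```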

answered Oct 18 '22 15:10 by mbpaulus