 

Batch-Matrix multiplication in Pytorch - Confused with the handling of the output's dimension

I have two arrays:

A
B

Array A contains a batch of RGB images, with shape:

[batch, Width, Height, 3]

whereas Array B contains coefficients needed for a "transformation-like" operation on images, with shape:

[batch, 4, 4, 3]

To put it simply, the operation for a single image is a multiplication that outputs an environment map (normalMap * Coefficients).

The output I want should hold shape:

[batch, Width, Height, 3]

I tried using torch.bmm, but it failed. Is this possible?
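The question does not say exactly how the 4x4 coefficients combine with a 3-channel normal, so here is a minimal sketch under one possible reading (an assumption, not the asker's confirmed operation): append a constant 1 to each normal to get a homogeneous 4-vector n, then compute the quadratic form n^T M_c n with the 4x4 matrix M_c for each color channel c. torch.einsum handles all batch and pixel dimensions at once and yields the requested [batch, Width, Height, 3] shape. The sizes and the names `n` and `out` are hypothetical.

```python
import torch

# Hypothetical sizes; the question does not give concrete values.
batch, width, height = 2, 8, 6

# A: per-pixel normals stored as an RGB image, shape [batch, W, H, 3]
A = torch.randn(batch, width, height, 3)
# B: one 4x4 coefficient matrix per color channel, shape [batch, 4, 4, 3]
B = torch.randn(batch, 4, 4, 3)

# Assumption: homogenize each normal by appending a constant 1
ones = torch.ones(batch, width, height, 1)
n = torch.cat([A, ones], dim=-1)  # [batch, W, H, 4]

# For every pixel and channel c compute n^T M_c n
#   b = batch, w/h = pixel position, i/j = 4x4 matrix indices, c = color channel
out = torch.einsum("bwhi,bijc,bwhj->bwhc", n, B, n)
print(out.shape)  # torch.Size([2, 8, 6, 3])
```

If the intended operation is different, only the einsum subscript string needs to change; the pattern of naming every axis explicitly avoids the 3-D restriction of torch.bmm entirely.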

asked Jun 11 '19 at 12:06 by singa1994



People also ask

What is batch matrix multiplication PyTorch?

torch.bmm performs batch matrix multiplication. Both inputs must be 3-D tensors with the same size in the first (batch) dimension: if input has shape (b, n, m) and mat2 has shape (b, m, p), the result has shape (b, n, p). Unlike torch.matmul, bmm does not broadcast.
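A minimal illustration of the shape rule for torch.bmm, using arbitrary small sizes chosen for the example:

```python
import torch

# Two batches of 5 matrices each; the batch size (first dim) must match
x = torch.randn(5, 2, 3)   # (b, n, m)
y = torch.randn(5, 3, 4)   # (b, m, p)

z = torch.bmm(x, y)        # (b, n, p)
print(z.shape)  # torch.Size([5, 2, 4])

# Each slice of the result is an ordinary 2-D matrix product
print(torch.allclose(z[0], x[0] @ y[0], atol=1e-6))
```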

Can you multiply matrix with different dimensions?

You can only multiply two matrices if their dimensions are compatible, which means the number of columns in the first matrix equals the number of rows in the second matrix.
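A short sketch of the compatibility rule in PyTorch: a 2x3 matrix can be multiplied by a 3x4 matrix, but not by another 2x3 matrix. The tensor names are arbitrary.

```python
import torch

a = torch.randn(2, 3)
b = torch.randn(3, 4)
c = torch.randn(2, 3)

# 2x3 times 3x4: inner dimensions (3 and 3) agree, so the result is 2x4
print((a @ b).shape)  # torch.Size([2, 4])

# 2x3 times 2x3: inner dimensions (3 and 2) differ, so PyTorch raises an error
try:
    a @ c
except RuntimeError as err:
    print("incompatible:", err)
```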

How do you do matrix multiplication in PyTorch?

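To multiply two 2-D matrices in PyTorch, use torch.mm(); torch.matmul() (also available as the @ operator) is more general and supports batched and broadcast inputs. NumPy's np.dot() is more flexible than torch.mm(): it computes the inner product for 1-D arrays and performs matrix multiplication for 2-D arrays.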
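The differences between these functions can be seen with small inputs; the sizes below are arbitrary:

```python
import numpy as np
import torch

m1 = torch.randn(2, 3)
m2 = torch.randn(3, 4)

# torch.mm: strictly 2-D x 2-D, no broadcasting
print(torch.mm(m1, m2).shape)  # torch.Size([2, 4])

# torch.matmul (the @ operator) also accepts batched / broadcast inputs
batched = torch.randn(7, 2, 3)
print(torch.matmul(batched, m2).shape)  # torch.Size([7, 2, 4])

# np.dot: inner product for 1-D arrays, matrix product for 2-D arrays
v = np.array([1.0, 2.0, 3.0])
print(np.dot(v, v))  # 14.0
print(np.dot(np.eye(2), np.array([[1.0, 2.0], [3.0, 4.0]])))
```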

How do you multiply dimensions of a matrix?

For matrix multiplication to be defined, the number of columns in the first matrix must equal the number of rows in the second matrix. To find AB, we take the dot product of each row of A with each column of B: entry (i, j) of AB is the dot product of row i of A and column j of B.
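The row-by-column rule can be checked directly against PyTorch's built-in product with two small 2x2 matrices:

```python
import torch

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

# Entry (i, j) of AB is the dot product of row i of A with column j of B
AB = torch.empty(2, 2)
for i in range(2):
    for j in range(2):
        AB[i, j] = torch.dot(A[i, :], B[:, j])

print(AB.tolist())  # [[19.0, 22.0], [43.0, 50.0]]
print(torch.equal(AB, A @ B))  # True
```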


1 Answer

I think you need to take into account that PyTorch works with the

B x C x H x W (batch size, channels, height, width)

format, and you should use matmul, since bmm only works with tensors of ndim/dim/rank = 3.
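Converting the asker's [batch, Width, Height, 3] layout into B x C x H x W is a single permute. A minimal sketch with hypothetical sizes:

```python
import torch

batch, width, height = 4, 32, 24   # hypothetical sizes
images = torch.randn(batch, width, height, 3)   # layout from the question: [batch, W, H, 3]

# Reorder axes to PyTorch's usual B x C x H x W layout:
# batch stays at 0, channels (dim 3) -> 1, height (dim 2) -> 2, width (dim 1) -> 3
images_bchw = images.permute(0, 3, 2, 1)
print(images_bchw.shape)  # torch.Size([4, 3, 24, 32])

# permute returns a view; .contiguous() copies it into contiguous memory
# in case a later operation requires that
images_bchw = images_bchw.contiguous()
```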

You can probably find this elsewhere online, but just in case:

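import torch
# Two 4-D tensors: (batch, channels, rows, cols)
batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)
# matmul multiplies the last two dims and treats the leading dims as batch dims
res = torch.matmul(batch1, batch2)
res.size()  # torch.Size([10, 3, 20, 30])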
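If you specifically want torch.bmm, one workaround (a sketch, not the answerer's code) is to fold the leading dimensions into a single batch dimension, call bmm, and reshape back. This reproduces the torch.matmul result for the same tensors:

```python
import torch

batch1 = torch.randn(10, 3, 20, 10)
batch2 = torch.randn(10, 3, 10, 30)

# torch.bmm only accepts 3-D tensors, so this 4-D call raises an error
try:
    torch.bmm(batch1, batch2)
except RuntimeError as err:
    print("bmm failed:", err)

# Fold the leading (10, 3) dims into one batch dim of 30, run bmm,
# then restore the original leading dims
res_bmm = torch.bmm(batch1.reshape(-1, 20, 10), batch2.reshape(-1, 10, 30))
res_bmm = res_bmm.reshape(10, 3, 20, 30)

# Same result as torch.matmul on the 4-D tensors
res = torch.matmul(batch1, batch2)
print(torch.allclose(res, res_bmm, atol=1e-5))  # True
```

matmul does this folding for you, which is why it is the simpler choice for 4-D inputs.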
answered Sep 29 '22 at 04:09 by prosti
