Why does tf.matmul(a,b, transpose_b=True) work, but not tf.matmul(a, tf.transpose(b))?

Code:

import tensorflow as tf

# Filled with 1.0; only the shapes matter here.
x = tf.constant(1., shape=(3, 2, 4))
y = tf.constant(1., shape=(3, 21, 4))
tf.matmul(x, y)                      # Doesn't work.
tf.matmul(x, y, transpose_b=True)    # This works. Shape is (3, 2, 21).
tf.matmul(x, tf.transpose(y))        # Doesn't work.

I want to know what shape y becomes inside tf.matmul(x,y,transpose_b = True) so I can work out what is really going on inside an LSTM with attention.

Asked by user3933614 on Jan 04 '18 at 17:01





1 Answer

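Transpose can be defined in more than one way for tensors of rank > 2, and the difference here is in which axes get swapped: tf.matmul(..., transpose_b=True) transposes only the last two axes of b (the matrix dimensions), while tf.transpose reverses all axes by default.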

By default, tf.transpose does this:

The returned tensor's dimension i will correspond to the input dimension perm[i]. If perm is not given, it is set to (n-1...0), where n is the rank of the input tensor. Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors.

So in your case, tf.transpose(y) has shape (4, 21, 3), which is not compatible with x of shape (3, 2, 4): neither the batch dimension (3 vs. 4) nor the inner matrix dimension (4 vs. 21) matches (see below).
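A small check of the default behaviour, assuming TensorFlow 2.x in eager mode. tf.ones is used only to build tensors with the question's shapes, and the exception types caught are a deliberately broad guess at what the shape mismatch raises.

```python
import tensorflow as tf

x = tf.ones((3, 2, 4))
y = tf.ones((3, 21, 4))

# With no perm, tf.transpose reverses every axis: (3, 21, 4) -> (4, 21, 3).
y_t = tf.transpose(y)
print(y_t.shape)  # (4, 21, 3)

# Multiplying x (3, 2, 4) by y_t (4, 21, 3) fails: the batch sizes (3 vs 4)
# and the inner matrix dimensions (4 vs 21) both disagree.
try:
    tf.matmul(x, y_t)
    default_ok = True
except (tf.errors.OpError, ValueError):
    default_ok = False
print(default_ok)  # False
```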

But if you set perm=[0, 2, 1], the result is compatible:

# Works! (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
tf.matmul(x, tf.transpose(y, [0, 2, 1]))
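This also answers what shape y takes inside tf.matmul(x, y, transpose_b=True): it is used as if it had shape (3, 4, 21). The sketch below assumes TensorFlow 2.x eager execution, random inputs, and an arbitrarily chosen 1e-4 tolerance. tf.linalg.matrix_transpose is included as another existing TensorFlow function that swaps only the last two axes.

```python
import tensorflow as tf

x = tf.random.normal((3, 2, 4))
y = tf.random.normal((3, 21, 4))

# transpose_b=True swaps only the last two axes of y inside the op,
# so y is used as if it had shape (3, 4, 21).
implicit = tf.matmul(x, y, transpose_b=True)

# Two explicit spellings of the same batch-wise transpose.
explicit_perm = tf.matmul(x, tf.transpose(y, perm=[0, 2, 1]))
explicit_mt = tf.matmul(x, tf.linalg.matrix_transpose(y))

print(implicit.shape)  # (3, 2, 21)
print(bool(tf.reduce_all(tf.abs(implicit - explicit_perm) < 1e-4)))  # True
print(bool(tf.reduce_all(tf.abs(implicit - explicit_mt) < 1e-4)))    # True
```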

About tf.matmul

tf.matmul can multiply rank-3 tensors: (a, b, c) * (a, c, d) -> (a, b, d). But this isn't a tensor dot product; it's a batch operation (see this question).

In this case, a is treated as the batch size, so tf.matmul computes a separate matrix product (b, c) * (c, d) for each of the a batch elements.
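To make the batch semantics concrete, the sketch below (assuming TensorFlow 2.x eager execution, random inputs, and an arbitrarily chosen 1e-4 tolerance) compares tf.matmul against an explicit Python loop of 2-D products over the batch axis. For contrast, it also calls tf.tensordot, a genuine tensor dot product, which pairs every batch index of x with every batch index of z.

```python
import tensorflow as tf

a, b, c, d = 3, 2, 4, 21
x = tf.random.normal((a, b, c))
z = tf.random.normal((a, c, d))

# Batched matmul: one (b, c) @ (c, d) product per batch index.
batched = tf.matmul(x, z)  # shape (a, b, d)

# The same computation written as an explicit loop over the batch axis.
looped = tf.stack([tf.matmul(x[i], z[i]) for i in range(a)])

print(batched.shape)  # (3, 2, 21)
print(bool(tf.reduce_all(tf.abs(batched - looped) < 1e-4)))  # True

# A tensor dot product over the c axis keeps both batch axes as free axes,
# giving shape (a, b, a, d) instead of (a, b, d).
full = tf.tensordot(x, z, axes=[[2], [1]])
print(full.shape)  # (3, 2, 3, 21)
```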

The batch can span more than one dimension, so this is also valid:

(a, b, c, d) * (a, b, d, e) -> (a, b, c, e)
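The same idea with two batch axes, again a sketch assuming TensorFlow 2.x, with arbitrary sizes a=5, b=3, c=2, d=4, e=21. tf.einsum is used here only as an independent way to spell the contraction: the subscripts keep a and b as batch axes and sum over d.

```python
import tensorflow as tf

# Two leading batch axes: (a, b, c, d) @ (a, b, d, e) -> (a, b, c, e).
p = tf.random.normal((5, 3, 2, 4))
q = tf.random.normal((5, 3, 4, 21))

out = tf.matmul(p, q)

# einsum spelling of the same contraction: sum over d, keep a and b as batch axes.
out_einsum = tf.einsum('abcd,abde->abce', p, q)

print(out.shape)  # (5, 3, 2, 21)
print(bool(tf.reduce_all(tf.abs(out - out_einsum) < 1e-4)))  # True
```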
Answered by Maxim on Nov 15 '22 at 05:11
