3-D batch matrix multiplication without knowing batch size

I'm writing a TensorFlow program that needs to multiply a batch of 2-D tensors (a 3-D tensor of shape [None, ...]) by a 2-D matrix W. This requires turning W into a 3-D tensor, which in turn requires knowing the batch size.

I have not been able to do this: tf.batch_matmul is no longer available, and x.get_shape().as_list()[0] returns None, which is invalid for a reshape or tile operation. Any suggestions? I've seen some people use config.cfg.batch_size, but I don't know what that is.

Pian Pawakapan asked Feb 10 '26 at 19:02



1 Answer

The solution is to combine tf.shape, which returns the shape at runtime, with tf.tile, which accepts that runtime (dynamic) value in its multiples argument.
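The shape bookkeeping behind the expand/tile/matmul steps can be checked with a concrete batch size. The following is a NumPy analogue, not TensorFlow code: np.expand_dims, np.tile and np.matmul stand in for tf.expand_dims, tf.tile and tf.matmul, and the batch size of 10 is an assumption chosen to match the feed_dict used below.

```python
import numpy as np

# Concrete stand-ins for the TensorFlow tensors (assumed batch size of 10):
# x has shape [batch, 2, 3], W has shape [3, 4].
batch_size = 10
x = np.ones([batch_size, 2, 3], dtype=np.float32)
W = np.ones([3, 4], dtype=np.float32)

# Analogue of tf.expand_dims(W, axis=0): shape (1, 3, 4)
W_expand = np.expand_dims(W, axis=0)

# Analogue of tf.tile(W_expand, multiples=[batch_size, 1, 1]): shape (10, 3, 4)
W_tile = np.tile(W_expand, (batch_size, 1, 1))

# Batched matmul: each x[i] (2x3) times W_tile[i] (3x4) gives a 2x4 matrix
result = np.matmul(x, W_tile)

print(W_expand.shape)  # (1, 3, 4)
print(W_tile.shape)    # (10, 3, 4)
print(result.shape)    # (10, 2, 4)
```

With all-ones inputs, every entry of result is 3.0 (a sum of three products of 1.0).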

import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

x = tf.placeholder(shape=[None, 2, 3], dtype=tf.float32)
W = tf.Variable(initial_value=np.ones([3, 4]), dtype=tf.float32)
print(x.shape)                # Static shape: (?, 2, 3); batch dimension unknown

batch_size = tf.shape(x)[0]   # A tensor that gets the batch size at runtime
W_expand = tf.expand_dims(W, axis=0)                          # (1, 3, 4)
W_tile = tf.tile(W_expand, multiples=[batch_size, 1, 1])      # (batch, 3, 4)
result = tf.matmul(x, W_tile) # Batched matmul: (batch, 2, 3) x (batch, 3, 4)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  feed_dict = {x: np.ones([10, 2, 3])}
  print(sess.run(batch_size, feed_dict=feed_dict))    # 10
  print(sess.run(result, feed_dict=feed_dict).shape)  # (10, 2, 4)
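Tiling is not the only way around the unknown batch size. Two alternatives avoid materializing a copy of W per example: reshaping with -1 (reshape infers that dimension, so the batch size never has to be known) and an einsum contraction. The sketch below is a NumPy analogue that checks both against the tiled result; TensorFlow provides tf.reshape (which also accepts -1) and tf.einsum, so the same idea should carry over, though that translation is an assumption here rather than something shown in the answer.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((10, 2, 3)).astype(np.float32)  # (batch, 2, 3)
W = rng.standard_normal((3, 4)).astype(np.float32)      # (3, 4)

# Reference: tile W across the batch, as in the answer.
tiled = np.matmul(x, np.tile(W[np.newaxis], (x.shape[0], 1, 1)))

# Alternative 1: fold the batch into the rows with -1, multiply, unfold.
# Only the inner sizes (3, then 2 and 4) are needed, never the batch size.
flat = np.reshape(x, (-1, 3)) @ W          # (20, 4)
reshaped = np.reshape(flat, (-1, 2, 4))    # (10, 2, 4)

# Alternative 2: contract the last axis of x with the first axis of W.
einsummed = np.einsum('bij,jk->bik', x, W)  # (10, 2, 4)

print(np.allclose(tiled, reshaped, atol=1e-5))   # True
print(np.allclose(tiled, einsummed, atol=1e-5))  # True
```

The reshape route relies only on the inner dimensions being known statically; if they are not, they can be read at runtime with tf.shape in the same way as the batch size.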
Maxim answered Feb 17 '26 at 18:02



