TensorFlow: how to batch mat-mul a batch tensor by a weight variable?

I have the following batch shape:

 [?,227,227]

And the following weight variable:

 weight_tensor = tf.truncated_normal([227, 227], stddev=0.1, mean=0.0)

 weight_var = tf.Variable(weight_tensor)

But when I do tf.batch_matmul:

 matrix = tf.batch_matmul(prev_net_2d,weight_var)

It fails with the following error:

ValueError: Shapes (?,) and () must have the same rank


So my question becomes: How do I do this?

How do I multiply each individual 227x227 picture in the batch by a single 2D weight variable, so that each picture yields a 227x227 output? The flattened version of this operation completely exhausts resources, and the gradient won't update the weights correctly in the flat form.


Alternatively: how do I split the incoming tensor along the batch dimension (?,) so that I can run the tf.matmul function on each of the split tensors with my weight_variable?

Chris asked Apr 26 '16 16:04


1 Answer

You could tile the weights along the first (batch) dimension so that their shape matches the batch:

weight_tensor = tf.truncated_normal([227, 227], stddev=0.1, mean=0.0)
weight_var = tf.Variable(weight_tensor)
# The batch dimension is unknown (?) when the graph is built, so read it dynamically.
batch_size = tf.shape(prev_net_2d)[0]
weight_var_batch = tf.tile(tf.expand_dims(weight_var, axis=0), [batch_size, 1, 1])
matrix = tf.matmul(prev_net_2d, weight_var_batch)

Note that tf.batch_matmul no longer exists: it was removed in TensorFlow 1.0, and tf.matmul now handles batched inputs.
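
To make the shapes concrete, here is a minimal sketch of the same idea in NumPy rather than TensorFlow, so it can be run without building a graph. The sizes (4 images of 5x5 instead of ?x227x227) and the names batch and weights are assumptions chosen for illustration. It compares the tiling approach with two alternatives that avoid copying the weights: broadcasting the 2D matrix across the batch, and an einsum contraction.

```python
import numpy as np

# Small stand-in sizes (assumption): the question uses 227x227 images
# and an unknown batch size.
batch_size, n = 4, 5

rng = np.random.default_rng(0)
batch = rng.normal(size=(batch_size, n, n))    # plays the role of prev_net_2d
weights = rng.normal(scale=0.1, size=(n, n))   # plays the role of weight_var

# Tiling: copy the weights once per batch element, then multiply pairwise.
weights_tiled = np.tile(weights[np.newaxis, :, :], (batch_size, 1, 1))
out_tiled = np.matmul(batch, weights_tiled)

# Broadcasting: the single 2-D matrix is applied to every image in the batch.
out_broadcast = np.matmul(batch, weights)

# Einsum: contract the last axis of each image with the first axis of the weights.
out_einsum = np.einsum('bij,jk->bik', batch, weights)

print(out_tiled.shape)  # (4, 5, 5)
```

All three results are numerically equal. tf.einsum accepts the same 'bij,jk->bik' subscript string, which may be worth trying if tiling a 227x227 variable across a large batch uses too much memory.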

arkhy answered Oct 23 '22 05:10