I have a variable `a` of dimension (1, 5) which I want to 'tile' as many times as the size of my mini-batch. For example, if the mini-batch size is 32, then I want to construct a tensor `c` of dimension (32, 5) where each row has the same values as the original (1, 5) variable `a`.
But I only know the mini-batch size at run time: it's the size of dimension 0 of a placeholder `b`, i.e. `tf.shape(b)[0]`.
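Here's my code to construct `c`:

```python
import numpy as np
import tensorflow as tf

a = tf.Variable(np.random.uniform(size=(1, 5)))
b = tf.placeholder(shape=[None, 12], dtype=tf.float32)
batch_size = tf.shape(b)[0]
c = tf.tile(a, tf.pack([batch_size, 1]))  # tf.pack was renamed tf.stack in TensorFlow 1.0
```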
This runs fine. However, `c.get_shape()` returns (?, ?). I don't understand why it doesn't return (?, 5) instead.
This is causing an issue later in my code when I construct a matrix variable `W` whose number of columns is `c.get_shape()[1]`, which I expect to be 5 rather than ?.
Any help would be appreciated. Thanks.
[EDIT: This was fixed in a commit to TensorFlow on August 10, 2016.]
This is a known limitation of TensorFlow's shape inference: when the `multiples` argument to `tf.tile()` is a computed value (such as the result of `tf.pack()` here), and its value is not trivially computable at graph construction time (in this case, because it depends on a `tf.placeholder()`, which has no value until it is fed), the current shape inference will throw its hands up and declare that the shape is unknown (but with the same rank as the input, `a`).
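To make that all-or-nothing behaviour concrete, here is a minimal pure-Python sketch of the rule as described above. It is an illustration, not TensorFlow's actual shape function, and `tile_static_shape` is a hypothetical name. Shapes are tuples of ints, with `None` for an unknown dimension, and `multiples=None` stands for "value not computable at graph construction time".

```python
from typing import Optional, Sequence, Tuple

# A static shape: ints for known dimensions, None for unknown ones.
Shape = Tuple[Optional[int], ...]


def tile_static_shape(input_shape: Shape,
                      multiples: Optional[Sequence[int]]) -> Shape:
    """Hypothetical all-or-nothing shape inference for a tile op.

    `multiples` is None when its value depends on something (such as a
    placeholder) that has no value until the graph is run.
    """
    if multiples is None:
        # Value unknown: keep the rank, give up on every dimension.
        return tuple(None for _ in input_shape)
    if len(multiples) != len(input_shape):
        raise ValueError("multiples needs one entry per input dimension")
    # Value fully known: each output dim is input dim times its multiple.
    return tuple(
        None if dim is None else dim * m
        for dim, m in zip(input_shape, multiples)
    )


print(tile_static_shape((1, 5), None))     # (None, None), i.e. (?, ?)
print(tile_static_shape((1, 5), [32, 1]))  # (32, 5)
```

Under this rule, the constant `1` in the second position of `multiples` is thrown away along with the unknown batch size, which is why the column count is lost.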
The current workaround is to use `Tensor.set_shape()`, which allows you as the programmer to provide additional shape information when you know more than the shape inference does. For example, you could do:
```python
import numpy as np
import tensorflow as tf

a = tf.Variable(np.random.uniform(size=(1, 5)))
b = tf.placeholder(shape=[None, 12], dtype=tf.float32)
batch_size = tf.shape(b)[0]
c = tf.tile(a, tf.pack([batch_size, 1]))
c.set_shape([None, a.get_shape()[1]])  # or `c.set_shape([None, 5])`
```
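Note that `set_shape()` does not overwrite the inferred shape; it merges the shape you give with the one already inferred, so each dimension must either be unknown on one side or agree on both (otherwise it raises a `ValueError`). The following pure-Python sketch models that merge rule; `merge_shapes` and `num_units` are hypothetical names used only for illustration, not TensorFlow APIs.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


def merge_shapes(inferred: Shape, declared: Shape) -> Shape:
    """Combine two partially known shapes, the way set_shape() refines one."""
    if len(inferred) != len(declared):
        raise ValueError(f"rank mismatch: {inferred} vs {declared}")
    merged = []
    for a, b in zip(inferred, declared):
        if a is None:
            merged.append(b)      # declared information fills the gap
        elif b is None or a == b:
            merged.append(a)      # keep what inference already knew
        else:
            raise ValueError(f"incompatible dimensions {a} and {b}")
    return tuple(merged)


c_shape = merge_shapes((None, None), (None, 5))
print(c_shape)  # (None, 5)

# With the column count now static, W's shape can be fully known.
num_units = 10  # hypothetical output width for W
w_shape = (c_shape[1], num_units)
print(w_shape)  # (5, 10)
```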
However, we recently added some features that make it possible to propagate partially computed values that may be used as shapes, and this can be adapted to aid the shape function for `tf.tile()`. I have created a GitHub issue to track this, and I have a fix being tested right now.
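The idea of propagating partially computed values can be sketched in pure Python: instead of treating `tf.pack([batch_size, 1])` as entirely unknown, evaluate it element by element, so the constant `1` survives even though `batch_size` does not. This is an illustration under that assumption, not the code of the actual TensorFlow fix; `Unknown`, `partial_pack_value` and `tile_partial_shape` are hypothetical names.

```python
from typing import Optional, Sequence, Tuple

Shape = Tuple[Optional[int], ...]


class Unknown:
    """Stand-in for a scalar tensor whose value is only known at run time."""


def partial_pack_value(elements: Sequence[object]) -> Shape:
    """Partially evaluate a pack of scalars: keep constants, mark the rest None."""
    return tuple(e if isinstance(e, int) else None for e in elements)


def tile_partial_shape(input_shape: Shape, multiples: Shape) -> Shape:
    """Compute every output dim whose input dim and multiple are both known."""
    return tuple(
        dim * m if dim is not None and m is not None else None
        for dim, m in zip(input_shape, multiples)
    )


batch_size = Unknown()  # plays the role of tf.shape(b)[0]
multiples = partial_pack_value([batch_size, 1])
print(multiples)                              # (None, 1)
print(tile_partial_shape((1, 5), multiples))  # (None, 5)
```

With this per-element view, the batch dimension stays unknown but the column count of 5 is recovered, which is the (?, 5) shape the question expected.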