I have a rank-3 tensor named tensor of shape [batch_size, axis_1, axis_2] and want to split it into batch_size slices along the first axis like so:
batch_size = tf.shape(tensor)[0]
batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)
Unfortunately, this doesn't work because the value of batch_size isn't yet known during construction of the graph. How can I solve this?
I get this error:
TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.
Weirdly, using batch_size in other TensorFlow functions seems to work:
tensor = tf.reshape(tensor, [batch_size, -1])
This works fine even though the value of batch_size is unknown during graph construction. Is the problem specific to tf.split()?
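For context: tf.split returns a Python list, so the number of pieces has to be a plain Python int when the graph is built, whereas tf.reshape takes its target shape as an ordinary tensor input. When the batch dimension happens to be known statically, that static value can be passed instead of tf.shape(tensor)[0]. The following is only a minimal sketch, assuming TensorFlow 2.x (where tensor.shape[0] is an int or None); the helper name split_along_batch is hypothetical.

```python
import tensorflow as tf


def split_along_batch(tensor):
    """Split `tensor` into one slice per batch item, using the static shape.

    Hypothetical helper for illustration; assumes TensorFlow 2.x.
    """
    # Static shape: a Python int if known at construction time, else None.
    static_batch = tensor.shape[0]
    if static_batch is None:
        # tf.split cannot build a Python list of unknown length.
        raise ValueError("Batch size is not statically known.")
    # Each returned slice keeps the leading axis, i.e. has shape [1, axis_1, axis_2].
    return tf.split(tensor, num_or_size_splits=static_batch, axis=0)


batch_items = split_along_batch(tf.zeros([2, 3, 4]))
```

This does not help when the batch dimension really is dynamic; it only covers the case where a static value was available all along.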
A workaround is:
# Index along the first (batch) axis; tensor[..., k] would slice the last axis instead.
batch_items = tf.map_fn(fn=lambda k: tensor[k],
                        elems=tf.range(batch_size),
                        dtype=tf.float32)
I'm still interested in better solutions though.
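One possible alternative is tf.TensorArray, whose size may be a scalar tensor rather than a Python int. The sketch below assumes TensorFlow 2.x with tf.function (the question itself appears to be about graph-mode TensorFlow 1.x, so this may need adapting). The function name sum_per_item, the [None, 3, 4] input shape, and the per-item reduce_sum are hypothetical placeholders for whatever per-slice work is actually needed.

```python
import tensorflow as tf


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 3, 4], dtype=tf.float32)])
def sum_per_item(tensor):
    # Dynamic batch size: a scalar int32 tensor, unknown while tracing.
    batch_size = tf.shape(tensor)[0]

    # Unstack along axis 0 into a TensorArray; its size may be a tensor.
    items = tf.TensorArray(tf.float32, size=batch_size).unstack(tensor)

    # Collect one result per batch item.
    out = tf.TensorArray(tf.float32, size=batch_size)
    for k in tf.range(batch_size):
        item = items.read(k)  # shape [3, 4]
        out = out.write(k, tf.reduce_sum(item))
    return out.stack()  # shape [batch_size]


x = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])
result = sum_per_item(x)
```

Unlike tf.split, this never materialises a Python list of slices; each slice is read inside the loop, which is why the batch size can stay dynamic.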