TensorFlow: Split a tensor into `batch_size` slices

I have a rank-3 tensor named `tensor` of shape [batch_size, axis_1, axis_2] and want to split it into batch_size slices along the first axis, like so:

batch_size = tf.shape(tensor)[0]

batch_items = tf.split(tensor, num_or_size_splits=batch_size, axis=0)

Unfortunately, this doesn't work because the value of batch_size isn't yet known during construction of the graph.

How can I solve this?

I get this error:

TypeError: Expected int for argument 'num_split' not <tf.Tensor 'decoded_predictions/strided_slice_15:0' shape=() dtype=int32>.

Weirdly, trying to use batch_size in other TensorFlow functions seems to work:

tensor = tf.reshape(tensor, [batch_size, -1])

works fine, even though the value of batch_size is just as unknown during graph construction.

Is the problem specific to tf.split()?
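A minimal sketch of why the two calls behave differently, assuming TensorFlow 2.x and using `tf.function` with a `None` batch dimension to stand in for graph construction (the function names `reshape_ok` and `split_fails` are made up for illustration). `tf.reshape` receives its target shape as a tensor that is evaluated at run time, whereas `tf.split` returns a Python list of tensors, so the number of pieces has to be a plain Python int while the graph is being built.

```python
import tensorflow as tf

# Assumption: TF 2.x. A None first dimension in the input_signature means the
# batch size is unknown while the function is traced into a graph.
SPEC = [tf.TensorSpec([None, 3, 4], tf.float32)]


@tf.function(input_signature=SPEC)
def reshape_ok(tensor):
    batch_size = tf.shape(tensor)[0]  # scalar int32 tensor, not a Python int
    # The shape argument of tf.reshape is itself a tensor input to the op,
    # so a value that only exists at run time is acceptable.
    return tf.reshape(tensor, [batch_size, -1])


@tf.function(input_signature=SPEC)
def split_fails(tensor):
    batch_size = tf.shape(tensor)[0]
    # tf.split must create a fixed number of output tensors during tracing,
    # so a run-time tensor is rejected as the split count.
    return tf.split(tensor, num_or_size_splits=batch_size, axis=0)
```

Calling `reshape_ok` on a [2, 3, 4] input returns a [2, 12] tensor, while calling `split_fails` raises an exception during tracing; the exact exception type and message depend on the TensorFlow version (the `TypeError` in the question is one such message).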

Alex asked Nov 08 '22 10:11


1 Answer

A work-around is to do:

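# tf.map_fn iterates over tf.range(batch_size) at run time and stacks the
# results into one tensor; tensor[k] selects batch item k along axis 0
# (tensor[..., k] would index the last axis instead).
batch_items = tf.map_fn(fn=lambda k: tensor[k],
                        elems=tf.range(batch_size),
                        dtype=tensor.dtype)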

I'm still interested in better solutions though.
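One alternative, sketched under the assumption of TensorFlow 2.x: pass `tensor` itself as `elems`, and `tf.map_fn` unstacks it along axis 0 at run time, so no index range is needed. For random access to individual items, a `tf.TensorArray` accepts a run-time `size` and can be unstacked and then read by index. The names `per_item_sums`, `item_at`, and the row-sum inside `item_fn` are hypothetical placeholders for whatever per-item work is actually needed.

```python
import tensorflow as tf

# Assumption: TF 2.x, with the batch dimension left as None (unknown).
BATCH_SPEC = tf.TensorSpec([None, 3, 4], tf.float32)


@tf.function(input_signature=[BATCH_SPEC])
def per_item_sums(tensor):
    def item_fn(item):
        # `item` is one [axis_1, axis_2] slice; the row sum is only an example.
        return tf.reduce_sum(item, axis=1)

    # With the tensor as `elems`, map_fn splits it along axis 0 at run time
    # and stacks the per-item results: shape [batch_size, axis_1].
    return tf.map_fn(item_fn, tensor)


@tf.function(input_signature=[BATCH_SPEC, tf.TensorSpec([], tf.int32)])
def item_at(tensor, k):
    batch_size = tf.shape(tensor)[0]
    # Unlike tf.split's count, a TensorArray's size may be a run-time tensor.
    items = tf.TensorArray(tensor.dtype, size=batch_size).unstack(tensor)
    return items.read(k)  # one [axis_1, axis_2] slice
```

Neither approach produces a Python list of length batch_size: when the batch size is unknown at graph-construction time, the slices have to live inside a single tensor or a `TensorArray`.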

Alex answered Nov 14 '22 21:11