Is it possible to use a placeholder input for the num_split parameter to tf.split()?
I would ideally like to do something like:
num_splits = tf.placeholder(tf.int32)
inputs = tf.placeholder(tf.int32, [5, None])
split_inputs = tf.split(1, num_splits, inputs)
This fails with:
TypeError: Expected int for argument 'num_split' not <tf.Tensor ...>.
There is possibly something wrong with my approach; I am looking to enumerate across a dimension of a variable-shape tensor. Thanks!
There's a general philosophy of "tensor in, tensor out" for core graph ops, so it may simplify things if you can restructure your computation to deal with a single tensor of variable size instead of a variable number of tensors.
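For instance, for the example above you could keep everything as one tensor and reduce over the fixed dimension, letting the variable dimension flow through. A minimal sketch (not from the original answer), assuming a per-column sum is the computation you wanted and the same pre-1.0 graph API used in the code below:

import tensorflow as tf

# One op over the whole tensor instead of one op per split: the variable
# number of columns just flows through the graph.
inputs = tf.placeholder(tf.int32, [5, None])
per_column_sums = tf.reduce_sum(inputs, 0)   # shape [None]: one sum per column

with tf.Session() as sess:
    print(sess.run(per_column_sums,
                   feed_dict={inputs: [[1, 1, 1], [2, 2, 2], [3, 3, 3],
                                       [4, 4, 4], [5, 5, 5]]}))
    # -> [15 15 15], and it works for any number of columns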
Ops like pack, unpack, and split deal with multiple tensors, but they compile into "tensor-in/tensor-out" ops at graph construction time, which is why num_split needs to be fixed. Ops like dynamic_partition, dynamic_stitch, and dequeue_many take over some of that functionality for single tensors with a variable 0-th dimension.
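As a small illustration of that single-tensor style (not from the original answer), here is a sketch using dynamic_partition with the same graph API: the number of partitions stays a static Python int, but the row-to-output assignment and the number of rows are only known at run time.

import tensorflow as tf

# num_partitions must still be a Python int, but which rows go where is a
# tensor computed at run time, and the input can have any number of rows.
data = tf.placeholder(tf.int32, [None, 2])
partitions = tf.placeholder(tf.int32, [None])
part0, part1 = tf.dynamic_partition(data, partitions, num_partitions=2)

with tf.Session() as sess:
    a, b = sess.run([part0, part1],
                    feed_dict={data: [[0, 0], [1, 1], [2, 2]],
                               partitions: [0, 1, 0]})
    print(a)  # rows sent to partition 0: [[0 0] [2 2]]
    print(b)  # rows sent to partition 1: [[1 1]]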
If you really need to deal with a variable number of tensors, the typical approach is to break the computation into multiple session.run calls, with one input tensor per run call, and to tie things together using queues. There's a slice_input_producer which splits a variable-sized input along the 0-th dimension and produces a tensor for each row, so if you wanted to evaluate myfunction in a loop on each row of inputs you could do this:
import tensorflow as tf

def myfunction(vector):
    result = tf.reduce_sum(vector)
    print_result = tf.Print(result, [result], "myfunction called ")
    return print_result

MAX_ROWS = 10

# input matrix with 2 columns and unknown number of rows (<MAX_ROWS)
inputs = tf.placeholder(tf.int32, [None, 2])
# copy of inputs; we need a persistent copy because we will be fetching
# rows in different session.run calls
data = tf.Variable(inputs, validate_shape=False)
# input producer that iterates over the rows and pushes them onto a Queue
row = tf.train.slice_input_producer([data], num_epochs=1, shuffle=False)[0]
myfunction_op = myfunction(row)

# this op will save the placeholder values into the variable
init_op = tf.initialize_all_variables()

sess = tf.Session()
# initialize the variable from the placeholder before the queue runners
# start reading rows from it
sess.run([init_op], feed_dict={inputs: [[0, 0], [1, 1], [2, 2]]})

# Coordinator is not necessary in this case, but you'll need it if you have
# more than one Queue in order to close all queues together
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

try:
    for i in range(MAX_ROWS):
        sess.run([myfunction_op])
except tf.errors.OutOfRangeError:
    print('Done iterating')
finally:
    # When done, ask the queue-runner threads to stop.
    coord.request_stop()
    coord.join(threads)
If you run this, you should see each row's sum printed:
myfunction called [0]
myfunction called [2]
myfunction called [4]
Done iterating