
Using a variable for num_splits for tf.split()

Is it possible to use a placeholder input for the num_split parameter to tf.split()?

I would ideally like to do something like:

num_splits = tf.placeholder(tf.int32)
inputs = tf.placeholder(tf.int32, [5, None])
split_inputs = tf.split(1, num_splits, inputs)

but this fails with:

TypeError: Expected int for argument 'num_split' not .

There is possibly something wrong with my approach overall. I am looking to enumerate across a dimension of a variable-shape tensor. Thanks!

asked Jan 24 '16 by bbhat

1 Answer

There's a general philosophy of "tensor in, tensor out" for core graph ops, so it may simplify things if you can restructure your computation to deal with a single tensor of variable size instead of a variable number of tensors.
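For instance, if the goal were to split that [5, None] input into num_splits equal chunks along the second dimension and sum each chunk, a reshape plus a reduction keeps everything as a single tensor. This is only a minimal sketch, assuming equal-sized chunks and the same 0.x-era API as the rest of this answer; the chunk-sum goal is just for illustration:

import tensorflow as tf

num_splits = tf.placeholder(tf.int32, [])     # number of equal chunks
inputs = tf.placeholder(tf.int32, [5, None])  # width must be divisible by num_splits
# reshape [5, width] -> [5, num_splits, width/num_splits], then reduce each chunk
chunked = tf.reshape(inputs, tf.pack([5, num_splits, -1]))
chunk_sums = tf.reduce_sum(chunked, reduction_indices=[2])  # shape [5, num_splits]

sess = tf.Session()
print(sess.run(chunk_sums, feed_dict={num_splits: 2,
                                      inputs: [[1, 2, 3, 4]] * 5}))
# each row: [3 7]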

Ops like pack, unpack, and split deal with multiple tensors, but they compile into "tensor-in/tensor-out" ops at graph-construction time, which is why num_splits needs to be fixed. Ops like dynamic_partition, dynamic_stitch, and dequeue_many take over some of that functionality for single tensors with a variable 0-th dimension.
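As a rough illustration of that single-tensor style, here is a minimal sketch of dynamic_partition handling a batch with an unknown number of rows (the per-row group ids are made up for the example; note the number of partitions still has to be a Python int at graph-construction time):

import tensorflow as tf

data = tf.placeholder(tf.int32, [None, 2])   # unknown number of rows
part_ids = tf.placeholder(tf.int32, [None])  # partition id (0 or 1) per row
parts = tf.dynamic_partition(data, part_ids, num_partitions=2)

sess = tf.Session()
first, second = sess.run(parts, feed_dict={data: [[0, 0], [1, 1], [2, 2], [3, 3]],
                                           part_ids: [0, 1, 0, 1]})
print(first)   # rows with id 0: [[0 0] [2 2]]
print(second)  # rows with id 1: [[1 1] [3 3]]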

If you really need to deal with a variable number of tensors, the typical approach is to break the computation into multiple session.run calls, with one input tensor per run call, and tie things together using queues. There's a slice_input_producer which slices a variable-sized input along the 0-th dimension and produces a tensor for each row, so if you wanted to evaluate myfunction in a loop on each row of inputs you could do it like this:

import tensorflow as tf

def myfunction(vector):
  result = tf.reduce_sum(vector)
  print_result = tf.Print(result, [result], "myfunction called ")
  return print_result

MAX_ROWS = 10

# input matrix with 2 columns and unknown number of rows (<MAX_ROWS)
inputs = tf.placeholder(tf.int32, [None, 2])
# copy of inputs, will need to have a persistent copy of it because we will
# be fetching rows in different session.run calls
data = tf.Variable(inputs, validate_shape=False)
# input producer that iterates over the rows and pushes them onto Queue
row = tf.train.slice_input_producer([data], num_epochs=1, shuffle=False)[0]
myfunction_op = myfunction(row)

# this op will save placeholder values into the variable
init_op = tf.initialize_all_variables()

# Coordinator is not necessary in this case, but you'll need it if you have
# more than one Queue in order to close all queues together
sess = tf.Session()
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

sess.run([init_op], feed_dict={inputs:[[0, 0], [1, 1], [2, 2]]})

try:
  for i in range(MAX_ROWS):
    sess.run([myfunction_op])
except tf.errors.OutOfRangeError:
  print('Done iterating')
finally:
  # When done, ask other threads to stop.
  coord.request_stop()

If you run this, you should see:

myfunction called [0]
myfunction called [2]
myfunction called [4]
Done iterating
answered Sep 23 '22 by Yaroslav Bulatov