 

dynamic_partition with dynamic num_partitions

Tags:

tensorflow

The argument num_partitions of tf.dynamic_partition is not a Tensor but a Python int. Therefore, if the number of partitions isn't known in advance, it can't be inferred from the data (e.g., by counting the number of unique values), nor can it be fed through a tf.placeholder. How can one use this method in such a dynamic scenario?

If that's impossible, a plausible workaround is to set this parameter to some upper bound. Then at run time some of the output lists will be empty. The question is how to eliminate those empty lists.
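A plain-Python model of this upper-bound workaround may make the problem concrete. This is a sketch, not TensorFlow code: `partition_with_bound` and `trim_trailing_empty` are hypothetical helpers that only mimic the list-in, lists-out behaviour of `tf.dynamic_partition`. Partitioning into an upper bound of buckets always yields a fixed number of outputs. If the indices actually used are dense (0..k-1), the unused buckets are exactly the trailing ones, so slicing to `max(idx) + 1` removes them without shifting any partition's index.

```python
from typing import List, Sequence


def partition_with_bound(data: Sequence[int], idx: Sequence[int],
                         max_partitions: int) -> List[List[int]]:
    """Hypothetical model of tf.dynamic_partition: always returns max_partitions lists."""
    if len(data) != len(idx):
        raise ValueError("data and idx must have the same length")
    outputs: List[List[int]] = [[] for _ in range(max_partitions)]
    for value, i in zip(data, idx):
        if not 0 <= i < max_partitions:
            raise ValueError(f"partition index {i} outside [0, {max_partitions})")
        outputs[i].append(value)  # input order is preserved within a partition
    return outputs


def trim_trailing_empty(outputs: List[List[int]], idx: Sequence[int]) -> List[List[int]]:
    """Keeps partitions 0..max(idx); every partition after that is empty."""
    last = max(idx) if idx else -1
    kept = outputs[: last + 1]
    # An empty list *before* the last used index means idx skipped a value;
    # slicing cannot remove it without renumbering the partitions.
    if any(len(p) == 0 for p in kept):
        raise ValueError("empty partitions must come last")
    return kept


parts = partition_with_bound([1, 2, 3, 4, 5], [0, 0, 1, 1, 1], max_partitions=3)
print(parts)                                        # [[1, 2], [3, 4, 5], []]
print(trim_trailing_empty(parts, [0, 0, 1, 1, 1]))  # [[1, 2], [3, 4, 5]]
```

If the partition numbers themselves don't matter, `[p for p in parts if p]` drops every empty list instead, at the cost of losing the mapping from index to partition.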

Thanks!

asked Dec 31 '16 13:12 by Oren


1 Answer

To get a fully dynamic partition, you could have an op that returns a single Tensor with a dynamic shape rather than a number of tensors fixed at Python (graph-construction) time. The problem is that tensors must be rectangular, and your partitions can have different lengths. To work around this, you can encode your variable-size lists as strings, which have shape (), or use a TensorArray. Here's a way to do it using PNG encoding for array<=>string conversion:
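The core trick does not depend on TensorFlow: once every partition is serialized to a single string, a batch of ragged partitions becomes a flat 1-D list of scalars. A minimal plain-Python sketch, assuming every value fits in one byte (0-255, the same limit the uint8 PNG encoding imposes); `encode_partitions` and `decode_partition` are hypothetical helpers, and raw `bytes` stand in for PNG:

```python
from typing import List, Sequence


def encode_partitions(partitions: Sequence[Sequence[int]]) -> List[bytes]:
    """Turns ragged integer lists into a flat list of scalar byte strings."""
    # bytes() raises ValueError for any value outside 0..255, mirroring
    # the uint8 restriction of the PNG-based version.
    return [bytes(p) for p in partitions]


def decode_partition(encoded: bytes) -> List[int]:
    """Recovers the original integer list from one encoded string."""
    return list(encoded)


ragged = [[1, 2], [3, 4, 5], []]
encoded = encode_partitions(ragged)
print(encoded)                                 # [b'\x01\x02', b'\x03\x04\x05', b'']
print([decode_partition(e) for e in encoded])  # [[1, 2], [3, 4, 5], []]
```

Unlike PNG, an empty `bytes` object is a valid encoding, so no placeholder is needed for empty partitions; a PNG image cannot have zero height, which is why the TensorFlow version substitutes a dummy image.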

import tensorflow as tf  # TF 1.x graph-mode API (tf.Session)

def dynamic_partition_png(vals, idx, max_partitions):
    """Encodes output of dynamic partition as a Tensor of png-encoded strings."""
    max_idx = tf.reduce_max(idx)    # index of the last non-empty partition
    max_vals = tf.reduce_max(vals)  # each value is stored as one uint8 pixel
    with tf.control_dependencies([tf.Assert(max_vals < 256, ["vals must be <256"])]):
        outputs = tf.dynamic_partition(vals, idx, num_partitions=max_partitions)
    png_outputs = []
    # a PNG can't have zero height, so empty partitions get a 1x1 placeholder image
    dummy_png = tf.image.encode_png(tf.constant([[[2]]], dtype=tf.uint8))
    not_empty_ops = []  # ops that detect empty lists that aren't at the end
    for i, o in enumerate(outputs):
        # view each partition as an n x 1 single-channel image
        reshaped_o = tf.reshape(tf.cast(o, tf.uint8), [-1, 1, 1])
        png_output = tf.cond(tf.size(reshaped_o) > 0,
                             lambda: tf.image.encode_png(reshaped_o),
                             lambda: dummy_png)
        png_outputs.append(png_output)
        not_empty_ops.append(tf.logical_or(i > max_idx, tf.size(reshaped_o) > 0))
    packed_tensor = tf.stack(png_outputs)  # tf.stack replaces the removed tf.pack
    no_illegal_empty_lists = tf.reduce_all(tf.stack(not_empty_ops))
    with tf.control_dependencies([tf.Assert(no_illegal_empty_lists, ["empty lists must be last"])]):
        # keep partitions 0..max_idx; the trailing empty ones are dropped
        result = packed_tensor[:max_idx+1]
    return result

def decode(p):
    # PNG string -> uint8 image of shape [n, 1, 1] -> flat vector of length n
    return tf.image.decode_png(p)[:, 0, 0]

sess = tf.Session()
vals = tf.constant([1, 2, 3, 4, 5])
idx = [0, 0, 1, 1, 1]
tf_vals = dynamic_partition_png(vals, idx, 3)
print(sess.run(decode(tf_vals[0])))  # => [1 2]
print(sess.run(decode(tf_vals[1])))  # => [3 4 5]
print(sess.run(decode(tf_vals[2])))  # => error: slice index 2 of dimension 0 out of bounds (empty partition 2 was trimmed)
answered Sep 20 '22 22:09 by Yaroslav Bulatov