The num_partitions argument of tf.dynamic_partition is not a Tensor but a Python int. Therefore, if the number of partitions isn't known in advance, it can't be inferred from the data (e.g., by counting the number of unique values), nor can it be fed through a tf.placeholder.
How can one use this method in such a dynamic scenario?
If that's impossible, a plausible workaround is to set this parameter to some upper bound. Then at run time some of the output lists will be empty. The question is: how can those empty lists be eliminated?
Thanks!
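To make the upper-bound workaround concrete, here is a plain-Python model of what tf.dynamic_partition computes. It is not TensorFlow code, and the helper names (dynamic_partition_model, drop_empty) and the bound UPPER_BOUND are hypothetical. With num_partitions set to an upper bound, the unused slots come back as empty lists, which can then be filtered out afterwards.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def dynamic_partition_model(data: Sequence[T], partitions: Sequence[int],
                            num_partitions: int) -> List[List[T]]:
    """Hypothetical pure-Python model of tf.dynamic_partition:
    data[i] is appended to output number partitions[i]."""
    if len(data) != len(partitions):
        raise ValueError("data and partitions must have the same length")
    outputs: List[List[T]] = [[] for _ in range(num_partitions)]
    for value, p in zip(data, partitions):
        if not 0 <= p < num_partitions:
            raise ValueError(f"partition index {p} not in [0, {num_partitions})")
        outputs[p].append(value)
    return outputs


def drop_empty(outputs: List[List[T]]) -> List[List[T]]:
    """Remove the empty partitions left over from an over-estimated bound."""
    return [o for o in outputs if o]


# Hypothetical upper bound on the number of partitions.
UPPER_BOUND = 5
parts = dynamic_partition_model([1, 2, 3, 4, 5], [0, 0, 2, 2, 2], UPPER_BOUND)
print(parts)              # [[1, 2], [], [3, 4, 5], [], []]
print(drop_empty(parts))  # [[1, 2], [3, 4, 5]]
```

Note that dropping every empty list also renumbers the partitions (partition 2 becomes the second list here). If partition numbers must be preserved, trim only the trailing empties by slicing to max(partitions) + 1, which is what dynamic_partition_png in the answer does with packed_tensor[:max_idx + 1].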
To get a fully dynamic partition, you could have an op that returns a single Tensor with a dynamic shape rather than a number of tensors fixed at Python (graph-construction) time. The problem is that tensor dimensions must be rectangular, while your partitions can have different lengths. To work around this, you can encode each variable-size list as a string, which has shape (), or use a TensorArray. Here's a way to do it using PNG encoding for the array<=>string conversion:
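import tensorflow as tf  # TF 1.x graph-mode API (tf.Session, tf.cond)

def dynamic_partition_png(vals, idx, max_partitions):
    """Encodes output of dynamic partition as a Tensor of png-encoded strings."""
    max_idx = tf.reduce_max(idx)
    # PNG pixels are uint8 here, so every value must fit in [0, 256).
    vals_in_range = tf.logical_and(tf.reduce_min(vals) >= 0, tf.reduce_max(vals) < 256)
    with tf.control_dependencies([tf.Assert(vals_in_range, ["vals must be in [0, 256)"])]):
        outputs = tf.dynamic_partition(vals, idx, num_partitions=max_partitions)
    png_outputs = []
    # Placeholder image used in place of empty partitions.
    dummy_png = tf.image.encode_png(tf.constant([[[2]]], dtype=tf.uint8))
    not_empty_ops = []  # ops that detect empty lists that aren't at the end
    for i, o in enumerate(outputs):
        # Shape [n, 1, 1]: an n-pixel-high, 1-pixel-wide grayscale image.
        reshaped_o = tf.reshape(tf.cast(o, tf.uint8), [-1, 1, 1])
        png_output = tf.cond(tf.size(reshaped_o) > 0,
                             lambda: tf.image.encode_png(reshaped_o),
                             lambda: dummy_png)
        png_outputs.append(png_output)
        # An empty partition is only legal if it comes after max(idx).
        not_empty_ops.append(tf.logical_or(i > max_idx, tf.size(reshaped_o) > 0))
    packed_tensor = tf.stack(png_outputs)  # tf.stack was named tf.pack before TF 1.0
    no_illegal_empty_lists = tf.reduce_all(tf.stack(not_empty_ops))
    with tf.control_dependencies([tf.Assert(no_illegal_empty_lists, ["empty lists must be last"])]):
        # Keep only partitions 0..max(idx); the trailing empty ones are dropped.
        result = packed_tensor[:max_idx + 1]
    return result

def decode(p):
    # Decoded image has shape [n, 1, 1]; take the single pixel column.
    return tf.image.decode_png(p)[:, 0, 0]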
sess = tf.Session()
vals = tf.constant([1, 2, 3, 4, 5])
idx = [0, 0, 1, 1, 1]
tf_vals = dynamic_partition_png(vals, idx, 3)
print(sess.run(decode(tf_vals[0])))  # => [1 2]
print(sess.run(decode(tf_vals[1])))  # => [3 4 5]
print(sess.run(decode(tf_vals[2])))  # => error: slice index 2 of dimension 0 out of bounds (only max(idx) + 1 = 2 partitions are kept)
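For intuition, the string-encoding idea can be modeled in plain Python: each variable-length partition becomes one scalar string, so the collection of partitions is a rectangular 1-D container. This sketch uses Python bytes instead of PNG (a substitution for illustration only; encode_partitions and decode_partition are hypothetical names). It shares the PNG version's restriction that every value must fit in one unsigned byte, i.e. 0 to 255.

```python
from typing import List


def encode_partitions(partitions: List[List[int]]) -> List[bytes]:
    """Encode each variable-length list as one bytes object (one 'scalar string')."""
    # bytes() raises ValueError for any value outside range(256),
    # mirroring the uint8 limit of the PNG approach.
    return [bytes(p) for p in partitions]


def decode_partition(encoded: bytes) -> List[int]:
    """Recover the integer list from its byte encoding."""
    return list(encoded)


encoded = encode_partitions([[1, 2], [3, 4, 5]])
print(encoded)                       # [b'\x01\x02', b'\x03\x04\x05']
print(decode_partition(encoded[1]))  # [3, 4, 5]
```

Here an empty list simply encodes to b'', whereas dynamic_partition_png substitutes a dummy image for empty partitions.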