I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use the comparison x < y as the source for the tf.cond pred argument:
pred: A scalar determining whether to return the result of fn1 or fn2.
But if I am working with batches, it looks like I need to iterate over the source tensor inside the graph, slice out every item in the batch, and apply tf.cond to each item. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? Can you advise on the right way to use it with a batch?
tf.cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred. tf.cond supports nested structures as implemented in tensorflow.python.
The tf.where() function returns elements from either the first or the second tensor, depending on the specified condition. Where the condition is true, it selects from the first tensor; otherwise it selects from the second. Syntax: tf.where(condition, a, b)
tf.where sounds like what you want: a vectorized selection between Tensors.
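As a minimal sketch of that vectorized selection (assuming TensorFlow 2.x with eager execution; the tensors a and b are made-up placeholders for the two per-element results), the batch comparison x < y can be passed straight to tf.where:

```python
import tensorflow as tf

# Two batches of 0/1 values, as in the question.
x = tf.constant([0, 1, 0, 1])
y = tf.constant([1, 1, 0, 0])

# Hypothetical per-element results for the "true" and "false" cases.
a = tf.constant([10, 20, 30, 40])
b = tf.constant([-1, -2, -3, -4])

# x < y is a boolean tensor with one entry per batch item;
# tf.where picks from `a` where it is True and from `b` elsewhere.
selected = tf.where(x < y, a, b)
print(selected.numpy())
```

Note that both a and b are computed in full here; tf.where only chooses between already-computed values and does not skip any work.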
tf.cond is a control flow modifier: it determines which ops are executed, so it's difficult to think of useful batch semantics.
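To illustrate the scalar requirement, here is a small sketch, again assuming TensorFlow 2.x with eager execution. Reducing the comparison with tf.reduce_all is only one illustrative way to collapse a batch into a single decision; it is not something the question asked for:

```python
import tensorflow as tf

x = tf.constant([0, 1, 0, 1])
y = tf.constant([1, 1, 1, 1])

# tf.cond needs a scalar predicate, so the element-wise comparison
# is reduced to a single boolean before it is used.
pred = tf.reduce_all(x <= y)

# Exactly one branch result is returned for the whole batch.
result = tf.cond(pred,
                 lambda: tf.constant(1.0),
                 lambda: tf.constant(-1.0))
print(result.numpy())
```

The whole batch follows the same branch, which is why tf.cond is a poor fit for per-item decisions.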
We can also put together a mixture of these operations: an operation which slices based on a condition and passes those slices to two branches.
import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)

  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter

  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result
Some examples:
vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)
cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)
with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())
Prints:
[ 0.2 1.10000002 2.20000005 3.0999999 4.19999981 5.0999999
6.19999981 7.0999999 8.19999981 9.10000038]
[[ 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. ]
[ 0.1 1.10000002 2.0999999 3.0999999 4.0999999
5.0999999 6.0999999 7.0999999 8.10000038 9.10000038]
[ 0.1 2.0999999 4.0999999 6.0999999 8.10000038
10.10000038 12.10000038 14.10000038 16.10000038 18.10000038]
[ 0. -3. -6. -9. -12. -15.
-18. -21. -24. -27. ]
[ 0.1 4.0999999 8.10000038 12.10000038 16.10000038
20.10000038 24.10000038 28.10000038 32.09999847 36.09999847]
[ 0.1 5.0999999 10.10000038 15.10000038 20.10000038
25.10000038 30.10000038 35.09999847 40.09999847 45.09999847]
[ 0. -6. -12. -18. -24. -30.
-36. -42. -48. -54. ]
[ 0.1 7.0999999 14.10000038 21.10000038 28.10000038
35.09999847 42.09999847 49.09999847 56.09999847 63.09999847]
[ 0.1 8.10000038 16.10000038 24.10000038 32.09999847
40.09999847 48.09999847 56.09999847 64.09999847 72.09999847]
[ 0. -9. -18. -27. -36. -45.
-54. -63. -72. -81. ]]
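The core of slicing_where is a gather, compute, scatter round trip. A stripped-down 1-D sketch of just that mechanism (assuming TensorFlow 2.x with eager execution; the values and the two branch computations are made up for illustration) might look like this:

```python
import tensorflow as tf

values = tf.constant([10.0, 11.0, 12.0, 13.0])
condition = tf.constant([True, False, True, False])

# Coordinates of the True and False entries, each with shape [count, 1].
true_indices = tf.cast(tf.where(condition), tf.int32)
false_indices = tf.cast(tf.where(tf.logical_not(condition)), tf.int32)

# Each "branch" only ever sees its own slice of the batch.
true_out = tf.gather_nd(values, true_indices) * 2.0    # illustrative true branch
false_out = tf.gather_nd(values, false_indices) + 1.0  # illustrative false branch

# Scatter each result back to its original position. The two scattered
# tensors are zero wherever the other one is filled, so adding them
# interleaves the outputs.
shape = tf.shape(values)
result = (tf.scatter_nd(true_indices, true_out, shape)
          + tf.scatter_nd(false_indices, false_out, shape))
print(result.numpy())
```

Unlike a plain tf.where selection, each branch here runs only on the items routed to it, which matters when the branches are expensive.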