 

How to use tf.cond for batch processing

I want to use tf.cond(pred, fn1, fn2, name=None) for conditional branching. Let's say I have two tensors, x and y. Each tensor is a batch of 0/1 values, and I want to use the comparison x < y as the source for tf.cond's pred argument, which is documented as:

pred: A scalar determining whether to return the result of fn1 or fn2.

But if I am working with batches, it looks like I would need to iterate over the source tensor inside the graph, slice out every item in the batch, and apply tf.cond to each item. That looks suspicious to me. Why does tf.cond accept only a scalar and not a batch? What is the right way to use it with a batch?
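To make the shape problem concrete, here is a plain-Python model of the setup (lists stand in for tensors; this is not TensorFlow code, and the values of x and y are made up for illustration). Comparing two batches elementwise yields one boolean per example, so there is no single scalar to pass as pred; branching per example means one decision per item, which is the loop pattern the question describes.

```python
# Plain-Python model of the question's setup; lists stand in for batch tensors.
# The sample values are hypothetical.
x = [0, 1, 0, 1]
y = [1, 1, 0, 0]

# Elementwise comparison, like `x < y` on two batch tensors:
# the result has one boolean per example, not a single scalar.
pred = [xi < yi for xi, yi in zip(x, y)]


def fn1(item):
    return "fn1"


def fn2(item):
    return "fn2"


# Per-example branching: one scalar decision (one "cond") per batch item.
per_item = [fn1(xi) if p else fn2(xi) for p, xi in zip(pred, x)]

print(pred)      # [True, False, False, False]
print(per_item)  # ['fn1', 'fn2', 'fn2', 'fn2']
```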

Brans Ds asked Feb 10 '17 12:02



People also ask

What does TF cond do?

tf.cond stitches together the graph fragments created during the true_fn and false_fn calls with some additional graph nodes to ensure that the right branch gets executed depending on the value of pred. tf.cond supports nested structures as implemented in tensorflow.python.
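The key property here is that only one branch runs, chosen by a single predicate. A minimal plain-Python sketch of that behaviour follows; `cond_model` and the branch functions are hypothetical, and Python's own `if` stands in for the graph machinery. It records which branch actually executed.

```python
executed = []


def true_fn():
    executed.append("true_fn")
    return 1.0


def false_fn():
    executed.append("false_fn")
    return -1.0


def cond_model(pred, true_fn, false_fn):
    """Hypothetical model of tf.cond: pred must be a single bool, and only
    the selected branch is called."""
    if not isinstance(pred, bool):
        raise TypeError("pred must be a single bool, not a batch")
    return true_fn() if pred else false_fn()


result = cond_model(True, true_fn, false_fn)
print(result, executed)  # 1.0 ['true_fn']
```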

What is TF Where?

The tf.where() function returns elements from either the first or the second tensor depending on the specified condition: where the condition is true, it selects from the first tensor; otherwise it selects from the second. Syntax: tf.where(condition, a, b)
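A plain-Python model of that elementwise selection (`where_model` is a hypothetical helper, not TensorFlow's implementation) shows that the condition is a whole batch rather than a scalar, and that each position is selected independently:

```python
def where_model(condition, a, b):
    """Elementwise select: a[i] where condition[i] is True, else b[i]."""
    if not (len(condition) == len(a) == len(b)):
        raise ValueError("condition, a and b must have the same length")
    return [ai if c else bi for c, ai, bi in zip(condition, a, b)]


selected = where_model([True, False, True], [1, 2, 3], [10, 20, 30])
print(selected)  # [1, 20, 3]
```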


1 Answer

tf.where sounds like what you want: a vectorized selection between Tensors.

tf.cond is a control flow modifier: it determines which ops are executed, and so it's difficult to think of useful batch semantics.
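Applied to the question's setup, the tf.where pattern means computing both branches for the entire batch and then selecting per element. The plain-Python sketch below models that (fn1 and fn2 are hypothetical stand-ins, and the batch values are made up). Note that both branches see every example, which is only acceptable when they are cheap and valid for all inputs; that cost is one reason to slice the batch instead.

```python
# Plain-Python model; lists stand in for batch tensors, values are hypothetical.
x = [0, 1, 0, 1]
y = [1, 1, 0, 0]
condition = [xi < yi for xi, yi in zip(x, y)]  # one bool per example


def fn1(batch):
    return [v + 100 for v in batch]


def fn2(batch):
    return [v - 100 for v in batch]


# Vectorized-selection pattern: both branches run over the whole batch...
out1 = fn1(x)
out2 = fn2(x)
# ...then each position takes its value from exactly one of them.
result = [a if c else b for c, a, b in zip(condition, out1, out2)]
print(result)  # [100, -99, -100, -99]
```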

We can also put together a mixture of these operations: an operation which slices based on a condition and passes those slices to two branches.

import tensorflow as tf
from tensorflow.python.util import nest

def slicing_where(condition, full_input, true_branch, false_branch):
  """Split `full_input` between `true_branch` and `false_branch` on `condition`.

  Args:
    condition: A boolean Tensor with shape [B_1, ..., B_N].
    full_input: A Tensor or nested tuple of Tensors of any dtype, each with
      shape [B_1, ..., B_N, ...], to be split between `true_branch` and
      `false_branch` based on `condition`.
    true_branch: A function taking a single argument, that argument having the
      same structure and number of batch dimensions as `full_input`. Receives
      slices of `full_input` corresponding to the True entries of
      `condition`. Returns a Tensor or nested tuple of Tensors, each with batch
      dimensions matching its inputs.
    false_branch: Like `true_branch`, but receives inputs corresponding to the
      false elements of `condition`. Returns a Tensor or nested tuple of Tensors
      (with the same structure as the return value of `true_branch`), but with
      batch dimensions matching its inputs.
  Returns:
    Interleaved outputs from `true_branch` and `false_branch`, each Tensor
    having shape [B_1, ..., B_N, ...].
  """
  full_input_flat = nest.flatten(full_input)
  # Coordinates of the True / False entries of `condition`, shape [count, N].
  true_indices = tf.where(condition)
  false_indices = tf.where(tf.logical_not(condition))
  # Gather each branch's slice; the N batch dimensions collapse into one.
  true_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=true_indices)
                     for input_tensor in full_input_flat])
  false_branch_inputs = nest.pack_sequence_as(
      structure=full_input,
      flat_sequence=[tf.gather_nd(params=input_tensor, indices=false_indices)
                     for input_tensor in full_input_flat])
  true_outputs = true_branch(true_branch_inputs)
  false_outputs = false_branch(false_branch_inputs)
  nest.assert_same_structure(true_outputs, false_outputs)
  def scatter_outputs(true_output, false_output):
    batch_shape = tf.shape(condition)
    # tf.rank(batch_shape) is 1 (a shape is a vector), so this drops the single
    # leading dimension of the gathered output and keeps the per-item shape.
    scattered_shape = tf.concat(
        [batch_shape, tf.shape(true_output)[tf.rank(batch_shape):]],
        0)
    # Scatter each branch's outputs back to their original batch positions.
    # Positions not written by a scatter are zero, so summing interleaves them.
    true_scatter = tf.scatter_nd(
        indices=tf.cast(true_indices, tf.int32),
        updates=true_output,
        shape=scattered_shape)
    false_scatter = tf.scatter_nd(
        indices=tf.cast(false_indices, tf.int32),
        updates=false_output,
        shape=scattered_shape)
    return true_scatter + false_scatter
  # Scatter every output Tensor and restore the branches' nested structure.
  result = nest.pack_sequence_as(
      structure=true_outputs,
      flat_sequence=[
          scatter_outputs(true_single_output, false_single_output)
          for true_single_output, false_single_output
          in zip(nest.flatten(true_outputs), nest.flatten(false_outputs))])
  return result
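Stripped of TensorFlow and restricted to one batch dimension and a single (non-nested) input, the gather / apply / scatter idea in slicing_where can be modelled in plain Python. `slicing_where_model` is a hypothetical illustration, not a drop-in replacement: it writes outputs straight into a list instead of summing two zero-filled scatters, and it records what each branch receives to show that each branch only sees its own slice.

```python
def slicing_where_model(condition, full_input, true_branch, false_branch):
    """Plain-Python model of slicing_where for one batch dimension."""
    # Analogue of tf.where(condition) / tf.where(tf.logical_not(condition)).
    true_idx = [i for i, c in enumerate(condition) if c]
    false_idx = [i for i, c in enumerate(condition) if not c]
    # Gather: each branch receives only its own slice of the batch.
    true_out = true_branch([full_input[i] for i in true_idx])
    false_out = false_branch([full_input[i] for i in false_idx])
    if len(true_out) != len(true_idx) or len(false_out) != len(false_idx):
        raise ValueError("each branch must return one output per input item")
    # Scatter: put every output back at its original batch position.
    result = [None] * len(condition)
    for i, value in zip(true_idx, true_out):
        result[i] = value
    for i, value in zip(false_idx, false_out):
        result[i] = value
    return result


seen = {}


def add_point_two(batch):
    seen["true"] = list(batch)
    return [v + 0.2 for v in batch]


def add_point_one(batch):
    seen["false"] = list(batch)
    return [v + 0.1 for v in batch]


out = slicing_where_model(
    condition=[i % 2 == 0 for i in range(10)],
    full_input=[float(i) for i in range(10)],
    true_branch=add_point_two,
    false_branch=add_point_one)
print([round(v, 2) for v in out])
# [0.2, 1.1, 2.2, 3.1, 4.2, 5.1, 6.2, 7.1, 8.2, 9.1]
```

The rounded values correspond to the vector_test output printed below, without its float32 rounding noise.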

Some examples:

vector_test = slicing_where(
    condition=tf.equal(tf.range(10) % 2, 0),
    full_input=tf.range(10, dtype=tf.float32),
    true_branch=lambda x: 0.2 + x,
    false_branch=lambda x: 0.1 + x)

cross_range = (tf.range(10, dtype=tf.float32)[:, None]
               * tf.range(10, dtype=tf.float32)[None, :])
matrix_test = slicing_where(
    condition=tf.equal(tf.range(10) % 3, 0),
    full_input=cross_range,
    true_branch=lambda x: -x,
    false_branch=lambda x: x + 0.1)

# TensorFlow 1.x graph mode: tensors are evaluated inside a Session.
with tf.Session():
  print(vector_test.eval())
  print(matrix_test.eval())

Prints:

[ 0.2         1.10000002  2.20000005  3.0999999   4.19999981  5.0999999
  6.19999981  7.0999999   8.19999981  9.10000038]
[[  0.           0.           0.           0.           0.           0.
    0.           0.           0.           0.        ]
 [  0.1          1.10000002   2.0999999    3.0999999    4.0999999
    5.0999999    6.0999999    7.0999999    8.10000038   9.10000038]
 [  0.1          2.0999999    4.0999999    6.0999999    8.10000038
   10.10000038  12.10000038  14.10000038  16.10000038  18.10000038]
 [  0.          -3.          -6.          -9.         -12.         -15.
  -18.         -21.         -24.         -27.        ]
 [  0.1          4.0999999    8.10000038  12.10000038  16.10000038
   20.10000038  24.10000038  28.10000038  32.09999847  36.09999847]
 [  0.1          5.0999999   10.10000038  15.10000038  20.10000038
   25.10000038  30.10000038  35.09999847  40.09999847  45.09999847]
 [  0.          -6.         -12.         -18.         -24.         -30.
  -36.         -42.         -48.         -54.        ]
 [  0.1          7.0999999   14.10000038  21.10000038  28.10000038
   35.09999847  42.09999847  49.09999847  56.09999847  63.09999847]
 [  0.1          8.10000038  16.10000038  24.10000038  32.09999847
   40.09999847  48.09999847  56.09999847  64.09999847  72.09999847]
 [  0.          -9.         -18.         -27.         -36.         -45.
  -54.         -63.         -72.         -81.        ]]
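In matrix_test the condition has one entry per row of the 10x10 input, so each whole row goes to exactly one branch: rows 0, 3, 6 and 9 are negated and every other row gets 0.1 added. A plain-Python check of a few entries follows (it uses Python floats, i.e. float64, so values such as 1.1 appear without the float32 rounding seen above, e.g. 1.10000002):

```python
size = 10
# Same values as cross_range: entry (i, j) is i * j.
cross_range = [[float(i * j) for j in range(size)] for i in range(size)]
# One condition per row (the leading batch dimension).
condition = [i % 3 == 0 for i in range(size)]

# Each row is handled entirely by one branch.
result = [
    [-v for v in row] if c else [v + 0.1 for v in row]
    for c, row in zip(condition, cross_range)
]

print([i for i, c in enumerate(condition) if c])  # [0, 3, 6, 9]
print(result[3][1:4])  # [-3.0, -6.0, -9.0]
```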
Allen Lavoie answered Oct 31 '22 23:10
