Determining if A Value is in a Set in TensorFlow

The tf.logical_or, tf.logical_and, and tf.select functions are very useful. (tf.select was replaced by tf.where in TensorFlow 1.0.)

However, suppose you have a value x and want to check whether it is one of a, b, c, d, e. In Python you would simply write:

if x in set([a, b, c, d, e]):
    pass  # do some action

As far as I can tell, the only way to do this in TensorFlow is to nest tf.logical_or calls around tf.equal comparisons. Here is one iteration of the idea, covering a through d:

tf.logical_or(
    tf.logical_or(tf.equal(x, a), tf.equal(x, b)),
    tf.logical_or(tf.equal(x, c), tf.equal(x, d))
)
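For reference, the nested pattern extends to any number of candidates by folding tf.logical_or over a list of tf.equal results with functools.reduce. This is a minimal sketch assuming TensorFlow 2 (eager execution); the helper name in_set_nested and the sample values are made up for illustration.

```python
import functools

import tensorflow as tf  # assumes TensorFlow 2.x with eager execution


def in_set_nested(x, candidates):
    """Hypothetical helper: True if x equals any value in `candidates`.

    Builds the same chain of tf.logical_or(tf.equal(...), ...) calls
    as the hand-written version, one call per extra candidate.
    """
    comparisons = [tf.equal(x, c) for c in candidates]
    return functools.reduce(tf.logical_or, comparisons)


x = tf.constant(4)
print(in_set_nested(x, [1, 2, 3, 4, 5]).numpy())  # True
print(in_set_nested(x, [1, 2, 3]).numpy())        # False
```

This still issues one comparison per candidate; it only removes the need to write the nesting by hand.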

I feel that there must be an easier way to do this in TensorFlow. Is there?

asked by LeavesBreathe on Jan 05 '16 at 17:01




3 Answers

Here are two solutions. Suppose we want to check whether query is in whitelist:

import tensorflow as tf  # TensorFlow 2.x (eager execution)

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])
query = "RESTAURANT"

# use broadcasting for an element-wise comparison against every entry
broadcast_equal = tf.equal(whitelist, query)

# method 1: cast the booleans to integers and sum them
broadcast_equal_int = tf.cast(broadcast_equal, tf.int8)
broadcast_sum = tf.reduce_sum(broadcast_equal_int)

# method 2: count the True entries directly
nz_cnt = tf.math.count_nonzero(broadcast_equal)

print(broadcast_equal.numpy(), broadcast_sum.numpy(), nz_cnt.numpy())
#=> [False False  True False] 1 1

So if either count is greater than 0, query is in the set.
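If you want a boolean rather than a count, tf.reduce_any over the same broadcast comparison gives it directly. A sketch assuming TensorFlow 2 (eager execution); the second query string "PRICE" is made up for illustration:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

whitelist = tf.constant(["CUISINE", "DISH", "RESTAURANT", "ADDRESS"])

# tf.equal broadcasts the scalar query against every entry of whitelist;
# tf.reduce_any collapses the resulting booleans into a single True/False.
is_member = tf.reduce_any(tf.equal(whitelist, "RESTAURANT"))
not_member = tf.reduce_any(tf.equal(whitelist, "PRICE"))

print(is_member.numpy(), not_member.numpy())  # True False
```

Because the result is a scalar boolean tensor, it can be used directly as the predicate of tf.cond without a separate "> 0" comparison.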

answered by eggie5 on Oct 20 '22 at 08:10



To provide a more concrete answer, say you want to check, for every element of a tensor x, whether it equals any value in a 1-D tensor s. You could do the following:

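import tensorflow as tf

tile_multiples = tf.concat([tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)  # [1, ..., 1, len(s)]
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)  # repeat each element of x len(s) times on a new last axis
x_in_s = tf.reduce_any(tf.equal(x_tile, s), -1)  # True where an element of x equals some element of s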

For example, for s and x:

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0], 
                  [4, 4, 4, 0, 0]], 
                 [[3, 5, 5, 6, 4], 
                  [4, 7, 3, 8, 9]]])

x has shape [2, 2, 5] and s has shape [2], so tile_multiples = [1, 1, 1, 2]. tf.expand_dims gives x a new trailing axis, and tf.tile repeats each element of x 2 times (once for each element in s) along that axis, so x_tile has shape [2, 2, 5, 2] and looks like:

[[[[1 1]
   [2 2]
   [3 3]
   [0 0]
   [0 0]]

  [[4 4]
   [4 4]
   [4 4]
   [0 0]
   [0 0]]]

 [[[3 3]
   [5 5]
   [5 5]
   [6 6]
   [4 4]]

  [[4 4]
   [7 7]
   [3 3]
   [8 8]
   [9 9]]]]

Then tf.equal(x_tile, s) compares the two copies of each element of x with the two values in s (s broadcasts along the new last axis), and tf.reduce_any along that axis returns True wherever at least one of the comparisons matched, giving the final result x_in_s:

[[[False False  True False False]
  [ True  True  True False False]]

 [[ True False False False  True]
  [ True False  True False False]]]
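Because tf.equal broadcasts, the explicit tf.tile step is not strictly necessary: adding a trailing axis to x with tf.expand_dims lets each element be compared against every element of s at once. A sketch assuming TensorFlow 2 (eager execution) that runs both versions on the example above and checks that they agree:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

s = tf.constant([3, 4])
x = tf.constant([[[1, 2, 3, 0, 0],
                  [4, 4, 4, 0, 0]],
                 [[3, 5, 5, 6, 4],
                  [4, 7, 3, 8, 9]]])

# Tile-based version from the answer: x_tile has shape [2, 2, 5, 2].
tile_multiples = tf.concat(
    [tf.ones(tf.shape(tf.shape(x)), dtype=tf.int32), tf.shape(s)], axis=0)
x_tile = tf.tile(tf.expand_dims(x, -1), tile_multiples)
x_in_s_tiled = tf.reduce_any(tf.equal(x_tile, s), -1)

# Broadcast version: [2, 2, 5, 1] compared with [2] broadcasts to [2, 2, 5, 2].
x_in_s = tf.reduce_any(tf.equal(tf.expand_dims(x, -1), s), axis=-1)

print(tile_multiples.numpy())  # [1 1 1 2]
print(bool(tf.reduce_all(tf.equal(x_in_s, x_in_s_tiled))))  # True
```

The broadcast form avoids building the tiled copy of x, although tf.equal still produces an intermediate boolean tensor with one extra axis of length len(s).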
answered by Emma Strubell on Oct 20 '22 at 06:10



Take a look at this related question: Count number of "True" values in boolean Tensor

You should be able to build a tensor consisting of [a, b, c, d, e] and then use tf.equal to check whether any of its entries is equal to x, for example by counting the True values as in that question.
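When each member of the set is itself a vector, so that the set is stored as the rows of a 2-D tensor, the comparison needs one extra reduction: tf.reduce_all checks that every component of a row matches x, and tf.reduce_any checks whether any row matched. A sketch assuming TensorFlow 2 (eager execution); the values are made up for illustration:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# Each row is one member of the set.
members = tf.constant([[1, 2],
                       [3, 4],
                       [5, 6]])
x = tf.constant([3, 4])

# [3, 2] compared with [2] broadcasts x against every row.
row_matches = tf.reduce_all(tf.equal(members, x), axis=1)
x_in_members = tf.reduce_any(row_matches)

# Counting the matching rows, as in the related question, works too.
num_matches = tf.math.count_nonzero(row_matches)

print(row_matches.numpy(), x_in_members.numpy(), num_matches.numpy())
# [False  True False] True 1
```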

answered by Rafał Józefowicz on Oct 20 '22 at 06:10
