
Equality comparison does not work inside TensorFlow 2.0 tf.function()

Following the discussion on TensorFlow 2.0 AutoGraphs, I've been playing around and noticed that inequality comparisons such as > and < are specified directly, whereas equality comparisons are represented using tf.equal.

Here's an example to demonstrate. This function uses > operator and works well when called:

import tensorflow as tf

@tf.function
def greater_than_zero(value):
    return value > 0

greater_than_zero(tf.constant(1))
# <tf.Tensor: id=1377, shape=(), dtype=bool, numpy=True>
greater_than_zero(tf.constant(-1))
# <tf.Tensor: id=1380, shape=(), dtype=bool, numpy=False>

Here is another function that uses equality comparison, but does not work:

@tf.function
def equal_to_zero(value):
    return value == 0

equal_to_zero(tf.constant(1))
# <tf.Tensor: id=1389, shape=(), dtype=bool, numpy=False>  # OK...

equal_to_zero(tf.constant(0))
# <tf.Tensor: id=1392, shape=(), dtype=bool, numpy=False>  # WHAT?

If I replace the == comparison with tf.equal, it works:

@tf.function
def equal_to_zero2(value):
    return tf.equal(value, 0)

equal_to_zero2(tf.constant(0))
# <tf.Tensor: id=1402, shape=(), dtype=bool, numpy=True>

My question is: Why does using inequality comparison operators work inside tf.function functions, whereas equality comparisons do not?

cs95 asked Jun 16 '19 05:06




1 Answer

I analyzed this behavior in part 3 of the article "Analysing tf.function to discover Autograph strengths and subtleties". I highly recommend reading all three parts to understand how to write a function correctly before decorating it with tf.function (links at the bottom of the answer).

For the __eq__ and tf.equal question, the answer is:

In short: the __eq__ operator for tf.Tensor has been overridden, but it does not use tf.equal to check tensor equality; it only checks Python object identity (if you are familiar with Java, this is just like the == operator applied to String objects). The reason is that a tf.Tensor object needs to be hashable, since it is used throughout the TensorFlow codebase as a key in dict objects.
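To make the identity-versus-value distinction concrete, here is a minimal pure-Python sketch. It is not TensorFlow code: FakeTensor and value_equal are hypothetical names invented for illustration, and the class only mimics the identity-based __eq__ and __hash__ described above.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor-like object; not TensorFlow code."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Identity comparison: true only when both operands are the same object.
        return self is other

    def __hash__(self):
        # Hash by identity so that instances can be used as dict keys.
        return id(self)


def value_equal(a, b):
    """Explicit value comparison, playing the role tf.equal plays in the answer."""
    return a.value == b.value


zero_a = FakeTensor(0)
zero_b = FakeTensor(0)

identity_result = zero_a == zero_b          # False: two distinct objects
same_object_result = zero_a == zero_a       # True: the very same object
value_result = value_equal(zero_a, zero_b)  # True: same stored value

# Identity-based hashing keeps the two objects as separate dict keys.
registry = {zero_a: "first", zero_b: "second"}
```

Under these assumptions, == reports False for two distinct objects holding the same value, while the explicit value comparison reports True, which mirrors the == versus tf.equal behavior in the question.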

For all the other operators, the answer is that AutoGraph doesn't convert Python operators to TensorFlow logical operators. In the section "How AutoGraph (don’t) converts the operators", I showed that every Python operator gets converted to a graph representation that always evaluates to false.

In fact, the following example prints "wat":

@tf.function
def if_elif(a, b):
  if a > b:
    tf.print("a > b", a, b)
  elif a == b:
    tf.print("a == b", a, b)
  elif a < b:
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

x = tf.constant(1)
if_elif(x, x)

In practice, AutoGraph is unable to convert this Python code to graph code on its own; we have to help it by using only TensorFlow primitives. With that change, the code works as you expect:

@tf.function
def if_elif(a, b):
  if tf.math.greater(a, b):
    tf.print("a > b", a, b)
  elif tf.math.equal(a, b):
    tf.print("a == b", a, b)
  elif tf.math.less(a, b):
    tf.print("a < b", a, b)
  else:
    tf.print("wat")

Here are the links to all three articles; I think you'll find them useful:

part 1, part 2, part 3

nessuno answered Oct 10 '22 04:10