
When should I use the @tf.function decorator, and when not? I know tf.function builds a graph, but how do I know when to build one?

I started my TensorFlow journey when it had already reached 2.0.0, so I never used graphs and sessions as in version 1. Recently I came across tf.function and AutoGraph, which suit me (but as far as I know, they are only used for the train step).

Now, when reading project code, I see many people use the tf.function decorator on many other functions when they want to build graphs, but I don't quite get their point. How do I know when to use a graph and when not?

Can anyone help me?

Leo asked May 05 '20 07:05



1 Answer

Solution

The @tf.function decorator converts a Python function into a static TensorFlow graph. TensorFlow has run in eager mode by default since version 2.0.0. Eager mode executes your code line by line, which makes it easy to inspect, but it is generally slower than executing a static graph. Converting a function into a static graph can therefore speed up execution, for example while training your model.

Quoting the tf.function documentation:

Functions can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.
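One rough way to check this on your own machine is to time the same function with and without tf.function. The sketch below is only an illustration: the function name many_small_ops, the loop length, and the repeat count are made up for this example, and the numbers you get depend on your hardware and TensorFlow version.

```python
import timeit

import tensorflow as tf


def many_small_ops(x):
    # 100 tiny element-wise updates: in eager mode each op is dispatched
    # from Python one at a time
    for _ in range(100):
        x = x * 1.0001 + 0.0001
    return x


# The same Python function wrapped as a graph function
graph_version = tf.function(many_small_ops)

x = tf.constant(1.0)
graph_version(x)  # the first call traces the graph; keep it out of the timing

eager_time = timeit.timeit(lambda: many_small_ops(x), number=100)
graph_time = timeit.timeit(lambda: graph_version(x), number=100)
print(f"eager: {eager_time:.4f}s  graph: {graph_time:.4f}s")
```

If the two timings are close for your real workload (for example, a few large convolutions), the decorator is probably not buying you much there.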

The static graph is created when the function is first traced for a given input signature and is then reused. Plain Python values that the function reads from outside its arguments are baked into the graph at that point, so the graph does not pick up later changes to them. Avoid @tf.function in such cases, or change the function definition (if possible) so that everything that varies is passed in through the input arguments. If your function gets all of its inputs through its arguments, applying @tf.function will not cause this problem.

Here is an example.

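### When not to use @tf.function ###
import time
import tensorflow as tf

# some variable that changes with time (here a plain Python float)
var = time.time()

@tf.function
def func(*args, **kwargs):
    # your code
    # `var` is read from the enclosing scope, not from the arguments,
    # so the value it has at trace time is frozen into the graph
    return var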

In the example above, func() depends on var but does not receive it through its arguments. When func() is called for the first time, @tf.function builds a static graph using the value var has at that moment. If var changes later, the static graph is not updated. I would highly encourage you to read the references at the end for more detail.
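To make the difference concrete, here is a minimal sketch that compares a captured Python value with a value passed as an argument. All names (scale, times_global, times_arg, factor, times_variable) are hypothetical. The tf.Variable case is an addition not covered above: a tf.Variable is read each time the graph runs, so it is not frozen the way a plain Python value is.

```python
import tensorflow as tf

scale = 2.0  # plain Python value, read when the function is traced


@tf.function
def times_global(x):
    # `scale` comes from the enclosing scope: its value is frozen into the graph
    return x * scale


@tf.function
def times_arg(x, s):
    # `s` is a tensor argument, so every call sees the value passed in
    return x * s


factor = tf.Variable(2.0)


@tf.function
def times_variable(x):
    # a tf.Variable is read each time the graph runs, so assignments are visible
    return x * factor


x = tf.constant(3.0)

before = float(times_global(x))  # traced here with scale == 2.0
scale = 10.0
after = float(times_global(x))   # same input signature, so the old graph is reused

via_arg = float(times_arg(x, tf.constant(10.0)))

var_before = float(times_variable(x))
factor.assign(10.0)
var_after = float(times_variable(x))

print(before, after, via_arg, var_before, var_after)
```

Here before and after are both 6.0 even though scale was changed in between, while via_arg and var_after are 30.0. Passing the factor as a tensor (rather than a Python number) also avoids retracing the function for every new value.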

For Debugging

Quoting source

You can use tf.config.experimental_run_functions_eagerly (which temporarily disables running functions as functions) for debugging purposes.
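As a small sketch of what that switch does, the example below counts how often the Python body of a decorated function actually runs. It assumes TensorFlow 2.3 or later, where the same switch is also available under the non-experimental name tf.config.run_functions_eagerly; the names calls and square are made up for this illustration.

```python
import tensorflow as tf

calls = []  # records every time the Python body is executed


@tf.function
def square(x):
    # A Python side effect: it runs only when the body is executed as Python,
    # i.e. during tracing or when functions are forced to run eagerly
    calls.append("body ran")
    return x * x


# Graph mode: the body runs once for tracing, then the graph is reused
square(tf.constant(2.0))
square(tf.constant(3.0))
graph_mode_runs = len(calls)

# Debug mode: every call executes the Python body line by line,
# so print() statements and breakpoints inside it behave as usual
tf.config.run_functions_eagerly(True)
square(tf.constant(2.0))
square(tf.constant(3.0))
eager_runs = len(calls) - graph_mode_runs

# Switch back so the rest of the program uses graphs again
tf.config.run_functions_eagerly(False)

print(graph_mode_runs, eager_runs)
```

With the switch off, the body runs only once for the two calls; with it on, it runs on every call. Remember to turn it off again after debugging, since it removes the speed benefit of tf.function.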

References

  1. Better performance with tf.function
  2. When to utilize tf.function
  3. TensorFlow 2.0: tf.function and AutoGraph
CypherX answered Sep 22 '22 05:09
