TensorFlow 2 warning when using @tf.function

This example code is from TensorFlow 2:

import tensorflow as tf

writer = tf.summary.create_file_writer("/tmp/mylogs/tf_function")

@tf.function
def my_func(step):
  with writer.as_default():
    # other model code would go here
    tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
  my_func(step)
  writer.flush()

but it throws this warning:

WARNING:tensorflow:5 out of the last 5 calls to triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

Is there a better way to do this?

Codey McCodeface asked Nov 21 '19


1 Answer

tf.function has some "peculiarities". I highly recommend reading this article: https://www.tensorflow.org/tutorials/customization/performance

In this case, the problem is that the function is "retraced" (i.e. a new graph is built) every time you call it with a different input signature. For tensors, the input signature refers to shape and dtype, but every distinct Python number counts as a new signature. Because you call the function with a step value that changes on every iteration, the function is retraced every single time. This becomes extremely slow for "real" code, e.g. when a model is called inside the function.
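
You can see the behavior for yourself with a small illustrative sketch (not part of the original code): the Python print below runs only while the function is being traced, so it fires once for every new input signature.

@tf.function
def traced(x):
    print("tracing with", x)  # executes at trace time only, not on every call
    return x + 1

traced(1)                           # Python int 1: traces
traced(2)                           # new Python value: traces again
traced(tf.constant(1, tf.int64))    # tensor: traced once for this shape/dtype
traced(tf.constant(2, tf.int64))    # same shape/dtype: graph is reused, no print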

You can fix it by simply converting step to a tensor, in which case the different values will not count as a new input signature:

for step in range(100):
    # tensors with the same shape and dtype reuse the already-traced graph
    step = tf.convert_to_tensor(step, dtype=tf.int64)
    my_func(step)
    writer.flush()

or use tf.range to get tensors directly:

for step in tf.range(100):
    step = tf.cast(step, tf.int64)  # tf.range yields int32; the summary step expects int64
    my_func(step)
    writer.flush()

This should not produce warnings (and be much faster).
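
If you would rather keep passing plain Python ints from the loop, another standard option (shown here as a sketch, not part of the original answer) is to declare an input_signature on the tf.function; incoming values are then converted to int64 tensors and the graph is traced only once:

@tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.int64)])
def my_func(step):
    with writer.as_default():
        tf.summary.scalar("my_metric", 0.5, step=step)

for step in range(100):
    my_func(step)   # the Python int is converted to match the declared signature
    writer.flush()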

xdurch0 answered Sep 22 '22