 

Tensorflow 2: Getting "WARNING:tensorflow:9 out of the last 9 calls to <function> triggered tf.function retracing. Tracing is expensive"

I think this warning comes from a problem with shapes, but I have no idea where. The complete warning message suggests the following:

Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing.

When I pass this argument to the function decorator, it does work:

@tf.function(experimental_relax_shapes=True)
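For reference, here is a minimal sketch of what that flag changes. It assumes TensorFlow 2.9 or later, where (as far as I know) `experimental_relax_shapes` was renamed `reduce_retracing`. On older 2.x releases, substitute `experimental_relax_shapes=True`. The function `relaxed_step` and the toy batch shapes are hypothetical. The Python list is appended to only while TensorFlow traces the function, so its length counts traces rather than calls:

```python
import tensorflow as tf

# Filled only while tf.function traces relaxed_step, not on every call.
traced_shapes = []


# Assumes TF >= 2.9; on older TF 2.x use @tf.function(experimental_relax_shapes=True).
@tf.function(reduce_retracing=True)
def relaxed_step(x):
    traced_shapes.append(tuple(x.shape))  # Python side effect: runs at trace time only
    return tf.reduce_sum(x)


# Hypothetical batches of token ids whose sequence length changes on every call.
for length in [5, 7, 9, 11]:
    relaxed_step(tf.ones([16, length], dtype=tf.int64))

print(traced_shapes)
```

With relaxation enabled, TensorFlow should switch to a more general signature with an unknown sequence dimension after it sees a second shape, instead of tracing once per length.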

What could be causing this? Here's the full code:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
print(f'Tensorflow version {tf.__version__}')
from tensorflow import keras
from tensorflow.keras.layers import Dense, Conv1D, GlobalAveragePooling1D, Embedding
import tensorflow_datasets as tfds
from tensorflow.keras.models import Model

(train_data, test_data), info = tfds.load('imdb_reviews/subwords8k',
                                          split=[tfds.Split.TRAIN, tfds.Split.TEST],
                                          as_supervised=True, with_info=True)

padded_shapes = ([None], ())

train_dataset = train_data.shuffle(25000).\
    padded_batch(padded_shapes=padded_shapes, batch_size=16)
test_dataset = test_data.shuffle(25000).\
    padded_batch(padded_shapes=padded_shapes, batch_size=16)

n_words = info.features['text'].encoder.vocab_size


class ConvModel(Model):
    def __init__(self):
        super(ConvModel, self).__init__()
        self.embe = Embedding(n_words, output_dim=16)
        self.conv = Conv1D(32, kernel_size=6, activation='elu')
        self.glob = GlobalAveragePooling1D()
        self.dens = Dense(2)

    def call(self, x, training=None, mask=None):
        x = self.embe(x)
        x = self.conv(x)
        x = self.glob(x)
        x = self.dens(x)
        return x


conv = ConvModel()

conv(next(iter(train_dataset))[0])

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_loss = tf.keras.metrics.Mean()
test_loss = tf.keras.metrics.Mean()

# Labels are integer class ids, so the sparse variant of the accuracy metric is needed.
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
test_acc = tf.keras.metrics.SparseCategoricalAccuracy()

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(inputs, labels):
    with tf.GradientTape() as tape:
        logits = conv(inputs, training=True)
        loss = loss_object(labels, logits)
        train_loss(loss)
        train_acc(labels, logits)  # metrics take (y_true, y_pred)

    gradients = tape.gradient(loss, conv.trainable_variables)
    optimizer.apply_gradients(zip(gradients, conv.trainable_variables))


@tf.function
def test_step(inputs, labels):
    logits = conv(inputs, training=False)
    loss = loss_object(labels, logits)
    test_loss(loss)
    test_acc(labels, logits)  # metrics take (y_true, y_pred)


def learn():
    train_loss.reset_states()
    test_loss.reset_states()
    train_acc.reset_states()
    test_acc.reset_states()

    for text, target in train_dataset:
        train_step(inputs=text, labels=target)

    for text, target in test_dataset:
        test_step(inputs=text, labels=target)


def main(epochs=2):
    for epoch in range(1, epochs + 1):  # plain Python loop; tf.range is not needed outside tf.function
        learn()
        template = 'TRAIN LOSS {:>5.3f} TRAIN ACC {:.2f} TEST LOSS {:>5.3f} TEST ACC {:.2f}'

        print(template.format(
            train_loss.result(),
            train_acc.result(),
            test_loss.result(),
            test_acc.result()
        ))

if __name__ == '__main__':
    main(epochs=1)
Asked Oct 16 '22 03:10 by Nicolas Gervais


1 Answer

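TL;DR: The root cause of this warning is that the shape of the input batches changes from batch to batch. With padded_shapes=([None], ()), padded_batch pads each batch only to the length of its own longest review, so the sequence dimension is different at almost every step. Fixing the padded length gives every batch the same shape and resolves the retracing warning. I changed the following line, and then everything worked as expected. The full gist is here.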

padded_shapes = ([9000], ())  # was ([None], ())

Details:

The full warning message reads:

WARNING:tensorflow:10 out of the last 11 calls to <function train_step at 0x7f4825f6d400> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing.

It lists three possible causes of retracing. Reason (1) is not the root cause, because train_step and test_step are decorated with @tf.function once, at module level; they are called in a loop but never re-created inside one. Reason (3) is not the root cause either, because both arguments of train_step and test_step are tensors. That leaves reason (2): tensors of different shapes are being passed in.
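Reason (2) can be reproduced in isolation. This is a minimal sketch with a hypothetical `step` function and made-up batch shapes, not the model from the question. The Python list is appended to only while TensorFlow traces the function, so it records one entry per trace:

```python
import tensorflow as tf

# Records the input shape each time TensorFlow traces step (not on every call).
traced_shapes = []


@tf.function
def step(x):
    traced_shapes.append(tuple(x.shape))  # Python side effect: runs at trace time only
    return tf.reduce_sum(x)


# Four hypothetical batches whose sequence length differs, as padded_batch
# produces when padded_shapes=([None], ()).
for length in [5, 7, 9, 11]:
    step(tf.ones([16, length], dtype=tf.int64))

# A shape that was already seen reuses the existing graph, so there is no new trace.
step(tf.ones([16, 5], dtype=tf.int64))

print(traced_shapes)  # [(16, 5), (16, 7), (16, 9), (16, 11)]
```

Each new sequence length produces a new trace, which is exactly what the warning counts.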

When I printed the shape of each batch, the sequence length differed from batch to batch. So I padded the text to a fixed length so that every batch has the same shape.
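The difference between the two padding settings can be checked on toy data. This sketch assumes TF 2.4 or later (for `output_signature`) and uses made-up sequences in place of the IMDB reviews. The fixed length of 8 is arbitrary:

```python
import tensorflow as tf

# Toy token-id sequences of different lengths, standing in for the IMDB reviews.
sequences = [[1, 2], [3, 4, 5, 6], [7], [8, 9, 10, 11, 12, 13]]
labels = [0, 1, 0, 1]


def gen():
    for seq, label in zip(sequences, labels):
        yield seq, label


ds = tf.data.Dataset.from_generator(
    gen,
    output_signature=(tf.TensorSpec(shape=[None], dtype=tf.int64),
                      tf.TensorSpec(shape=[], dtype=tf.int64)))

# [None]: each batch is padded only up to its own longest sequence.
dynamic = ds.padded_batch(batch_size=2, padded_shapes=([None], ()))
# [8]: every batch is padded to the same fixed length.
fixed = ds.padded_batch(batch_size=2, padded_shapes=([8], ()))

dynamic_shapes = [tuple(x.shape) for x, _ in dynamic]
fixed_shapes = [tuple(x.shape) for x, _ in fixed]
print(dynamic_shapes)  # [(2, 4), (2, 6)]
print(fixed_shapes)    # [(2, 8), (2, 8)]
```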

padded_shapes = ([9000], ())  # was ([None], ()), which caused retracing because the text length varied at each step of an epoch
# Because the input shape kept changing, tf.function retraced train_step and test_step for each new shape.
# 9000 is used as the maximum length for this demonstration; change it to suit your data.
Answered Oct 18 '22 14:10 by Vishnuvardhan Janapati