
TensorFlow 2.0: custom Keras metric causes tf.function retracing warning

When I use the following custom metric (keras-style):

from sklearn.metrics import classification_report, f1_score
from tensorflow.keras.callbacks import Callback

class Metrics(Callback):
    def __init__(self, dev_data, classifier, dataloader):
        super().__init__()  # initialise the Keras Callback base class
        self.best_f1_score = 0.0
        self.dev_data = dev_data
        self.classifier = classifier
        self.predictor = Predictor(classifier, dataloader)
        self.dataloader = dataloader

    def on_epoch_end(self, epoch, logs=None):
        print("start to evaluate....")
        _, preds = self.predictor(self.dev_data)
        y_trues, y_preds = [self.dataloader.label_vector(v["label"]) for v in self.dev_data], preds
        f1 = f1_score(y_trues, y_preds, average="weighted")
        print(classification_report(y_trues, y_preds,
                                    target_names=self.dataloader.vocab.labels))
        if f1 > self.best_f1_score:
            self.best_f1_score = f1
            self.classifier.save_model()
            print("best metrics, save model...")

I obtained the following warning:

W1106 10:49:14.171694 4745115072 def_function.py:474] 6 out of the last 11 calls to .distributed_function at 0x14a3f9d90> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings is likely due to passing python objects instead of tensors. Also, tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. Please refer to https://www.tensorflow.org/beta/tutorials/eager/tf_function#python_or_tensor_args and https://www.tensorflow.org/api_docs/python/tf/function for more details.

Sean Lee asked Jan 26 '23 15:01


1 Answer

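This warning appears when a tf.function is retraced repeatedly. A retrace happens whenever the function is called with arguments it has not been traced for yet: a new shape or dtype for Tensors, or simply a new value for plain Python arguments such as numbers or strings.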
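The difference between Python arguments and Tensor arguments can be seen by counting traces. The following is a minimal sketch, not code from the question: make_counted_function and double are hypothetical names, and the counter relies on the fact that Python side effects inside a tf.function run only while it is being traced.

```python
import tensorflow as tf


def make_counted_function():
    """Return a tf.function plus a dict counting how often it was traced."""
    counter = {"traces": 0}

    @tf.function
    def double(x):
        counter["traces"] += 1  # Python side effect: runs only during tracing
        return x * 2

    return double, counter


# Python ints are treated as constants, so every new value forces a new trace.
python_fn, python_counter = make_counted_function()
for value in (0, 1, 2):
    python_fn(value)

# Tensors are keyed by shape and dtype only, so one trace serves all three calls.
tensor_fn, tensor_counter = make_counted_function()
for value in (0, 1, 2):
    tensor_fn(tf.constant(value))

print("traces with Python ints:", python_counter["traces"])
print("traces with Tensors:", tensor_counter["traces"])
```

Converting Python values to Tensors before the call (for example with tf.constant or tf.convert_to_tensor) is therefore often enough to stop retracing.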

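In the general case, the fix is to put @tf.function(experimental_relax_shapes=True) above the definition of the custom function that you pass to Keras or TensorFlow. This lets TensorFlow relax the argument shapes it traces with, so that calls with differently shaped Tensors can share one trace. It avoids many unnecessary retraces, but it is not guaranteed to solve the issue. (Newer TensorFlow releases rename this option to reduce_retracing.)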

In your case, I guess that Predictor is a custom class, so place @tf.function(experimental_relax_shapes=True) before the definition of Predictor.predict().
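As a rough illustration of what shape relaxation does, the sketch below calls the same function with four different batch sizes, once through a plain tf.function and once through a shape-relaxing one. The names relaxed_function, count_traces and predict_step are hypothetical; predict_step only stands in for whatever Predictor.predict() does, since that class is not shown in the question. The try/except is there because the option is called reduce_retracing in newer TensorFlow releases and experimental_relax_shapes in older ones.

```python
import tensorflow as tf


def relaxed_function(fn):
    """Wrap fn in a shape-relaxing tf.function, whichever name the option has."""
    try:
        return tf.function(fn, reduce_retracing=True)  # newer TensorFlow
    except TypeError:
        return tf.function(fn, experimental_relax_shapes=True)  # older TensorFlow


def count_traces(wrap):
    """Call a wrapped function with four batch sizes and count its traces."""
    traces = {"count": 0}

    def predict_step(batch):  # hypothetical stand-in for Predictor.predict()
        traces["count"] += 1  # runs only while TensorFlow traces the function
        return tf.reduce_sum(batch, axis=-1)

    compiled = wrap(predict_step)
    for batch_size in (1, 2, 3, 4):
        compiled(tf.zeros([batch_size, 3]))
    return traces["count"]


strict_traces = count_traces(tf.function)        # one trace per distinct shape
relaxed_traces = count_traces(relaxed_function)  # shapes generalised after a miss
print("strict:", strict_traces, "relaxed:", relaxed_traces)
```

If the evaluation batches vary in shape, the relaxed variant should trace fewer times. If the retracing comes from Python arguments instead, relaxing shapes will not help and the arguments need to be converted to Tensors.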

Louis Becquey answered Jan 31 '23 11:01