 

Memory leak for custom tensorflow training using @tf.function

I am trying to write my own training loop for TF2/Keras, following the official Keras walkthrough. The vanilla version works like a charm, but when I add the @tf.function decorator to my training step, a memory leak grabs all my memory and I lose control of my machine. Does anyone know what is going on?

The important parts of the code look like this:

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y, val_logits)
    val_rec_metric.update_state(y, val_logits)


for epoch in range(epochs):
    step_time = 0
    epoch_time = time.time()
    print("Start of epoch {}".format(epoch))
    for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
        if step > steps_epoch:
            break
        loss_value = train_step(x_batch_train, y_batch_train)
    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()

    for val_step, (x_batch_val, y_batch_val) in enumerate(test_ds):
        if val_step > validation_steps:
            break
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_prec = val_prec_metric.result()
    val_rec = val_rec_metric.result()

    val_acc_metric.reset_states()
    val_prec_metric.reset_states()
    val_rec_metric.reset_states()

If I comment out the @tf.function lines, the memory leak doesn't occur, but each step is about three times slower. My guess is that somehow the graph is being created again within each epoch, but I have no idea how to solve it.
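
One way to check this guess (a diagnostic sketch, not part of my original loop): Python side effects such as print only run while tf.function traces a new graph, so repeated output across steps would mean retracing.

@tf.function
def train_step(x, y):
    # This print executes only during tracing; if it fires more than once
    # or twice per run, tf.function is building a new graph per call.
    print("Tracing train_step with x:", x.shape, "y:", y.shape)
    ...  # rest of the train_step body shown above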

This is the tutorial I am following: https://keras.io/guides/writing_a_training_loop_from_scratch/

asked Apr 15 '21 by este_banquito



1 Answer

tl;dr

TensorFlow may be generating a new graph for each unique set of argument values passed into the decorated functions. Make sure you are passing consistently-shaped Tensor objects to test_step and train_step instead of Python objects.

Details

This is a stab in the dark. While I've never tried @tf.function, I did find the following warnings in the documentation:

tf.function also treats any pure Python value as opaque objects, and builds a separate graph for each set of Python arguments that it encounters.

and

Caution: Passing python scalars or lists as arguments to tf.function will always build a new graph. To avoid this, pass numeric arguments as Tensors whenever possible

Finally:

A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key is a key that identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (which may change):

  • The key generated for a tf.Tensor is its shape and dtype.
  • The key generated for a tf.Variable is a unique variable id.
  • The key generated for a Python primitive (like int, float, str) is its value.
  • The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf-keys (see nest.flatten). (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError).
  • For all other Python types the key is unique to the object. This way a function or method is traced independently for each instance it is called with.
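
To make those rules concrete, here is a small sketch (my own, not from the docs) showing when a new trace happens:

import tensorflow as tf

@tf.function
def double(x):
    print("Tracing double for", x)  # Python print: runs only while tracing
    return x * 2

double(tf.constant(1.0))  # first call: traces for shape=(), dtype=float32
double(tf.constant(2.0))  # same shape and dtype: cached graph is reused, no print
double(2)                 # Python int: the cache key is the *value* 2, new trace
double(3)                 # value 3: yet another trace

In a training loop that feeds Python scalars this way, that last pattern repeats every step, and each abandoned graph holds on to memory.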

What I get from all this is that if you don't pass consistently-shaped Tensor objects to your @tf.function-ified functions (perhaps you use Python collections or primitives instead), it is likely that you are creating a new graph version of your function for every distinct argument value you pass in. I'm guessing this could create the memory explosion behavior you're seeing. I can't tell how your test_ds and train_ds objects are being created, but you might want to make sure that they are created such that enumerate(blah_ds) returns tensors like in the tutorial, or at least convert the values to tensors before passing them to your test_step and train_step functions.
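
If the inputs do turn out to be the culprit, two possible fixes (sketches only; the shapes below are placeholders I made up, not taken from the question):

import tensorflow as tf

# Option 1: keep every batch the same shape, so the smaller final batch of
# each epoch can't trigger a retrace (drop_remainder discards it):
# train_ds = train_ds.batch(batch_size, drop_remainder=True)

# Option 2: pin the accepted signature; tf.function then traces one graph,
# reuses it for any batch size, and validates inputs instead of retracing.
@tf.function(input_signature=[
    tf.TensorSpec(shape=[None, 105, 105, 1], dtype=tf.float32),  # x (guessed shape)
    tf.TensorSpec(shape=[None], dtype=tf.float32),               # y (guessed shape)
])
def train_step(x, y):
    ...  # same body as in the question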

answered Sep 22 '22 by xdhmoore