I am trying to write my own training loop for TF2/Keras, following the official Keras walkthrough. The vanilla version works like a charm, but when I add the @tf.function decorator to my training step, a memory leak grabs all my memory and I lose control of my machine. Does anyone know what is going on?
The important parts of the code look like this:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = siamese_network(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, siamese_network.trainable_weights)
    optimizer.apply_gradients(zip(grads, siamese_network.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value
@tf.function
def test_step(x, y):
    val_logits = siamese_network(x, training=False)
    val_acc_metric.update_state(y, val_logits)
    val_prec_metric.update_state(y, val_logits)  # was y_batch_val, leaked in from the outer loop's scope
    val_rec_metric.update_state(y, val_logits)
for epoch in range(epochs):
    step_time = 0
    epoch_time = time.time()
    print("Start of {} epoch".format(epoch))

    for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
        if step > steps_epoch:
            break
        loss_value = train_step(x_batch_train, y_batch_train)

    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()

    for val_step, (x_batch_val, y_batch_val) in enumerate(test_ds):
        if val_step > validation_steps:
            break
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_prec = val_prec_metric.result()
    val_rec = val_rec_metric.result()
    val_acc_metric.reset_states()
    val_prec_metric.reset_states()
    val_rec_metric.reset_states()
If I comment out the @tf.function lines, the memory leak doesn't occur, but each step is about 3 times slower. My guess is that the graph is somehow being created again within each epoch, or something like that, but I have no idea how to solve it.
This is the tutorial I am following: https://keras.io/guides/writing_a_training_loop_from_scratch/
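For reference, here is a minimal sketch of how the retracing hypothesis could be checked. It assumes a recent TF 2.x release, where callables wrapped by tf.function expose experimental_get_tracing_count():

# Sketch: if this counter keeps growing across epochs, tf.function is
# building a new graph (retracing) instead of reusing the existing one.
for epoch in range(epochs):
    for step, (x_batch_train, y_batch_train) in enumerate(train_ds):
        if step > steps_epoch:
            break
        train_step(x_batch_train, y_batch_train)
    print("train_step traces so far:",
          train_step.experimental_get_tracing_count())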
TensorFlow may be generating a new graph for each unique set of argument values passed into the decorated functions. Make sure you are passing consistently-shaped Tensor objects to test_step and train_step instead of Python objects.
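As a concrete (hypothetical) example of that advice: if the datasets come from tf.data, batching with drop_remainder=True keeps every batch the same shape, and pinning an input_signature on the decorated function guarantees a single graph regardless of batch size. The shapes, dtypes, and variable names below are placeholders, not taken from the question:

# Sketch: every batch gets an identical, fully-defined shape, so the
# tf.function cache key never changes and nothing is retraced.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(1024)
    .batch(batch_size, drop_remainder=True)
)

# Alternatively, fix the traced signature up front (placeholder shapes):
@tf.function(input_signature=[
    tf.TensorSpec(shape=(None, 128), dtype=tf.float32),  # x
    tf.TensorSpec(shape=(None,), dtype=tf.float32),      # y
])
def train_step(x, y):
    ...

With an input_signature, tf.function also raises an error instead of silently retracing if it is ever called with an incompatible argument, which makes this kind of leak easier to catch.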
This is a stab in the dark. While I've never tried @tf.function, I did find the following warnings in the documentation:
tf.function also treats any pure Python value as opaque objects, and builds a separate graph for each set of Python arguments that it encounters.
and
Caution: Passing python scalars or lists as arguments to tf.function will always build a new graph. To avoid this, pass numeric arguments as Tensors whenever possible
Finally:
A Function determines whether to reuse a traced ConcreteFunction by computing a cache key from an input's args and kwargs. A cache key is a key that identifies a ConcreteFunction based on the input args and kwargs of the Function call, according to the following rules (which may change):
- The key generated for a tf.Tensor is its shape and dtype.
- The key generated for a tf.Variable is a unique variable id.
- The key generated for a Python primitive (like int, float, str) is its value.
- The key generated for nested dicts, lists, tuples, namedtuples, and attrs is the flattened tuple of leaf-keys (see nest.flatten). (As a result of this flattening, calling a concrete function with a different nesting structure than the one used during tracing will result in a TypeError).
- For all other Python types the key is unique to the object. This way a function or method is traced independently for each instance it is called with.
What I get from all this is that if you don't pass in a consistently-sized Tensor object to your @tf.function-ified function (perhaps you use Python collections or primitives instead), it is likely that you are creating a new graph version of your function with every distinct argument value you pass in. I'm guessing this could create the memory explosion behavior you're seeing. I can't tell how your test_ds and train_ds objects are being created, but you might want to make sure that they are created such that enumerate(blah_ds) returns tensors like in the tutorial, or at least convert the values to tensors before passing them to your test_step and train_step functions.
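To make the quoted rules concrete, here is a small self-contained sketch (the function is mine, not from the question) showing the difference between Python scalars and tensors. The print only executes while a graph is being traced, so it reveals each retrace:

import tensorflow as tf

@tf.function
def double(a):
    print("Tracing with", a)  # runs only during tracing, not on graph reuse
    return a + a

double(1)               # traces a new graph keyed on the Python value 1
double(2)               # traces ANOTHER graph, keyed on the value 2
double(tf.constant(1))  # traces once, keyed on shape=() and dtype=int32
double(tf.constant(2))  # same shape/dtype: the graph is reused, no trace

Called in a loop with a stream of distinct Python values, a function like this accumulates one graph per value, which matches the unbounded memory growth described in the question.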