
Memory leak with tf.data

I'm creating a tf.data.Dataset inside a for loop, and I noticed that the memory is not freed after each iteration as one would expect.

Is there a way to request from TensorFlow to free the memory?

I tried tf.reset_default_graph() and calling del on the relevant Python objects, but neither works.

The only thing that seems to work is gc.collect(). Unfortunately, gc.collect() does not help in some more complex examples.
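
For reference, the only per-iteration cleanup that helped in the simple case is an explicit gc.collect() at the end of the loop body (a trimmed sketch of the full loop below; gc is Python's standard garbage-collection module):

import gc

import numpy as np
import tensorflow as tf

for i in range(500):
    data = tf.data.Dataset.from_tensor_slices(
        np.random.uniform(size=(10, 500, 500))).batch(3)
    it = data.make_initializable_iterator()
    nxt = it.get_next()
    with tf.Session() as sess:
        sess.run(it.initializer)
        sess.run(nxt)
    tf.reset_default_graph()  # on its own, this does not release the memory
    gc.collect()              # forcing a collection is what frees it here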

Fully reproducible code (TF 1.x):

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import psutil
%matplotlib inline

memory_used = []
for i in range(500):
    # a new dataset + iterator is built on every iteration,
    # adding fresh nodes to the default graph each time
    data = tf.data.Dataset.from_tensor_slices(
                    np.random.uniform(size=(10, 500, 500)))\
                    .prefetch(64)\
                    .repeat(-1)\
                    .batch(3)
    data_it = data.make_initializable_iterator()
    next_element = data_it.get_next()

    with tf.Session() as sess:
        sess.run(data_it.initializer)
        sess.run(next_element)
    memory_used.append(psutil.virtual_memory().used / 2 ** 30)  # in GB
    tf.reset_default_graph()  # does not release the memory

plt.plot(memory_used)
plt.title('Evolution of memory')
plt.xlabel('iteration')
plt.ylabel('memory used (GB)')

[Plot: evolution of memory used (GB) across iterations]

asked Mar 17 '19 by BiBi



2 Answers

The issue is that each loop iteration adds new nodes to the graph to define the dataset and iterator. A simple rule of thumb: never define new TensorFlow ops inside a loop. To fix it, move

data = tf.data.Dataset.from_tensor_slices(
            np.random.uniform(size=(10, 500, 500)))\
            .prefetch(64)\
            .repeat(-1)\
            .batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

outside the for loop, and just call sess.run(next_element) to fetch the next example. Once you've gone through all the training/eval examples, call sess.run(data_it.initializer) to reinitialize the iterator.
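
Put together, a minimal sketch of the restructured loop (same TF 1.x API and shapes as in the question):

import numpy as np
import tensorflow as tf

# define the pipeline once; the graph no longer grows per iteration
data = tf.data.Dataset.from_tensor_slices(
            np.random.uniform(size=(10, 500, 500)))\
            .prefetch(64)\
            .repeat(-1)\
            .batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

with tf.Session() as sess:
    sess.run(data_it.initializer)
    for i in range(500):
        sess.run(next_element)  # fetch the next batch
    # once an epoch/eval pass is done, reinitialize to start over:
    sess.run(data_it.initializer)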

answered Dec 07 '22 by Dpk


This fix worked for me when I had a similar issue with TF 2.4:

  1. sudo apt-get install libtcmalloc-minimal4
  2. LD_PRELOAD=/path/to/libtcmalloc_minimal.so.4 python example.py
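
For reference, on Ubuntu/Debian the package usually installs the library at /usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 (verify the path on your system with dpkg -L libtcmalloc-minimal4), so a typical invocation would look like:

LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4 python example.py

tcmalloc replaces glibc's default allocator and tends to suffer less from the heap fragmentation that can make a long-running TensorFlow process look like it is leaking memory.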
answered Dec 07 '22 by joakimedin