I'm creating a tf.data.Dataset inside a for loop, and I noticed that the memory is not freed after each iteration as one would expect.
Is there a way to ask TensorFlow to free the memory?
I tried using tf.reset_default_graph() and I tried calling del on the relevant Python objects, but neither works. The only thing that seems to help is gc.collect(). Unfortunately, gc.collect() does not work on some more complex examples.
Fully reproducible code:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import psutil
%matplotlib inline

memory_used = []
for i in range(500):
    data = tf.data.Dataset.from_tensor_slices(
        np.random.uniform(size=(10, 500, 500)))\
        .prefetch(64)\
        .repeat(-1)\
        .batch(3)
    data_it = data.make_initializable_iterator()
    next_element = data_it.get_next()

    with tf.Session() as sess:
        sess.run(data_it.initializer)
        sess.run(next_element)
    memory_used.append(psutil.virtual_memory().used / 2 ** 30)
    tf.reset_default_graph()

plt.plot(memory_used)
plt.title('Evolution of memory')
plt.xlabel('iteration')
plt.ylabel('memory used (GB)')
The issue is that each iteration adds new nodes to the graph to define the dataset and its iterator; a simple rule of thumb is to never define new TensorFlow ops inside a loop. To fix it, move
data = tf.data.Dataset.from_tensor_slices(
    np.random.uniform(size=(10, 500, 500)))\
    .prefetch(64)\
    .repeat(-1)\
    .batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()
outside the for loop and just call sess.run(next_element) to fetch the next batch. Once you've gone through all the training/eval examples, call sess.run(data_it.initializer) to reinitialize the iterator.
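For illustration, here is a minimal sketch of how the original reproducible example might look with that change applied. It assumes the same TF 1.x graph-mode API as above; unlike the original, the random data is generated once rather than once per iteration:

import tensorflow as tf
import numpy as np
import psutil

# Build the dataset and iterator once, outside the loop
data = tf.data.Dataset.from_tensor_slices(
    np.random.uniform(size=(10, 500, 500)))\
    .prefetch(64)\
    .repeat(-1)\
    .batch(3)
data_it = data.make_initializable_iterator()
next_element = data_it.get_next()

memory_used = []
with tf.Session() as sess:
    sess.run(data_it.initializer)
    for i in range(500):
        # Only fetches the next batch; no new graph nodes are created here
        sess.run(next_element)
        memory_used.append(psutil.virtual_memory().used / 2 ** 30)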
This fix worked for me when I had a similar issue with TF 2.4 (preloading tcmalloc in place of the default allocator):
sudo apt-get install libtcmalloc-minimal4
LD_PRELOAD=/path/to/libtcmalloc_minimal.so.4 python example.py
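The exact path of the library depends on your distribution; as an assumption, on Debian/Ubuntu it typically lands under /usr/lib/x86_64-linux-gnu/. One way to locate it, sketched here rather than guaranteed, is to query the package contents:

# Hypothetical helper: find where the package installed the library, then preload it
TCMALLOC=$(dpkg -L libtcmalloc-minimal4 | grep 'libtcmalloc_minimal\.so\.4$' | head -n1)
LD_PRELOAD="$TCMALLOC" python example.py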