 

TensorFlow DataSet API causes graph size to explode

Tags:

tensorflow

I have a very big data set for training.

I'm using the data set API like so:

self._dataset = tf.contrib.data.Dataset.from_tensor_slices((self._images_list, self._labels_list))

self._dataset = self._dataset.map(self.load_image)

self._dataset = self._dataset.batch(batch_size)
self._dataset = self._dataset.shuffle(buffer_size=shuffle_buffer_size)
self._dataset = self._dataset.repeat()

self._iterator = self._dataset.make_one_shot_iterator()

If I use only a small amount of the data for training, all is well. If I use all my data, TensorFlow crashes with this error: ValueError: GraphDef cannot be larger than 2GB.

It seems like TensorFlow tries to embed all the data in the graph instead of loading only the data that it needs... not sure...

Any advice will be great!

Update: found a solution/workaround

According to this post: Tensorflow Dataset API doubles graph protobuff filesize

I replaced make_one_shot_iterator() with make_initializable_iterator() and, of course, ran the iterator's initializer after creating the session:

init = tf.global_variables_initializer()
sess.run(init)
sess.run(train_data._iterator.initializer)

But I'm leaving the question open as to me it seems like a workaround and not a solution...
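Put together, the workaround described above can be sketched roughly like this. This is a minimal, self-contained sketch using the tf.compat.v1 API so it also runs under TF 2.x; images_list and labels_list are illustrative stand-ins for self._images_list and self._labels_list in the original code:

```python
import numpy as np
import tensorflow as tf

# Run in graph mode so Session/placeholder-style code works under TF 2.x.
tf.compat.v1.disable_eager_execution()

# Hypothetical stand-ins for self._images_list / self._labels_list.
images_list = np.array(["img_%d.png" % i for i in range(10)])
labels_list = np.arange(10, dtype=np.int64)

dataset = tf.data.Dataset.from_tensor_slices((images_list, labels_list))
dataset = dataset.batch(2)
dataset = dataset.repeat()

# The workaround: an initializable iterator instead of a one-shot iterator.
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(iterator.initializer)  # must run before pulling batches
    imgs, lbls = sess.run(next_batch)
```

Each call to sess.run(next_batch) then yields the next batch of (image path, label) pairs.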

asked Oct 17 '22 by Ohad Meir


1 Answer

https://www.tensorflow.org/guide/datasets#consuming_numpy_arrays

Note that the above code snippet will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory---because the contents of the array will be copied multiple times---and can run into the 2GB limit for the tf.GraphDef protocol buffer. As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

Instead of using:

dataset = tf.data.Dataset.from_tensor_slices((features, labels))

use:

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
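Completing the snippet above: the NumPy arrays are then fed through the placeholders when the iterator is initialized, so they never get baked into the GraphDef as constants. A minimal runnable sketch (using tf.compat.v1 so it works under TF 2.x; the array shapes are illustrative):

```python
import numpy as np
import tensorflow as tf

# Graph mode is required for Session/placeholder-style code under TF 2.x.
tf.compat.v1.disable_eager_execution()

# Illustrative data; in practice these could be arrays near the 2GB limit.
features = np.random.rand(100, 4).astype(np.float32)
labels = np.random.randint(0, 2, size=(100,)).astype(np.int64)

features_placeholder = tf.compat.v1.placeholder(features.dtype, features.shape)
labels_placeholder = tf.compat.v1.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices(
    (features_placeholder, labels_placeholder))
dataset = dataset.shuffle(buffer_size=100).batch(32).repeat()

iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_batch = iterator.get_next()

with tf.compat.v1.Session() as sess:
    # The actual arrays are fed here, at initialization time,
    # instead of being embedded in the graph as tf.constant() ops.
    sess.run(iterator.initializer,
             feed_dict={features_placeholder: features,
                        labels_placeholder: labels})
    batch_features, batch_labels = sess.run(next_batch)
```

Since the data enters through a feed at initialization rather than as graph constants, the GraphDef stays small regardless of the dataset size.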
answered Oct 21 '22 by draupnie