How to make tf.data.Dataset return all of the elements in one call?

Is there an easy way to get all of the elements of a tf.data.Dataset at once? That is, I want to set the batch size of the Dataset to the size of the dataset without explicitly passing it the number of elements. This would be useful for a validation dataset, where I want to measure accuracy on the entire dataset in one go. I'm surprised there isn't a method to get the size of a tf.data.Dataset.

asked Jan 06 '18 at 11:01 by Milad



2 Answers

TensorFlow now provides tf.data.Dataset.get_single_element(), which does exactly this: it returns all of the elements in one call.

This avoids having to create an iterator with iter() and loop over the dataset in Python, which can be costly for large datasets.

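get_single_element() returns the dataset's one and only element: a tensor (or a tuple or dict of tensors, matching the dataset's element structure). It raises an error if the dataset does not contain exactly one element, so the dataset must first be batched into a single element, for example with .batch() using a batch size at least as large as the number of elements.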
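A minimal sketch of that pattern, assuming a recent TensorFlow 2.x release where both Dataset.cardinality() and Dataset.get_single_element() are available. The toy dataset is only an illustration. cardinality() is used here to choose the batch size without hard-coding it; for some pipelines (for example after filter()) the size is unknown and cardinality() returns a negative value, in which case the batch size has to be supplied another way.

```python
import tensorflow as tf

# Toy dataset standing in for a real validation set (assumption for illustration).
ds = tf.data.Dataset.from_tensor_slices(tf.range(10))

# cardinality() returns the number of elements as a scalar tf.int64 tensor when
# TensorFlow can determine it; it is negative if the size is unknown or infinite.
n = int(ds.cardinality())
if n <= 0:
    raise ValueError("dataset is empty or its size is unknown/infinite; pass a batch size explicitly")

# Batch everything into one element, then extract that single element.
all_elements = ds.batch(n).get_single_element()
print(all_elements.numpy())  # [0 1 2 3 4 5 6 7 8 9]
```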

Depending on how the original dataset was created, this gives you the features as a single tensor, or the features and labels as a tuple or dict of tensors.
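For a labelled dataset the same call returns the whole structure at once. A small sketch with made-up features and labels (assumptions for illustration); the batch size 4 simply matches the size of the toy data.

```python
import tensorflow as tf

# Hypothetical toy data: 4 examples with 3 features each, plus integer labels.
features = tf.reshape(tf.range(12, dtype=tf.float32), (4, 3))
labels = tf.constant([0, 1, 0, 1])

# Tuple-structured dataset: the single batched element is a (features, labels) tuple.
ds = tf.data.Dataset.from_tensor_slices((features, labels))
x_all, y_all = ds.batch(4).get_single_element()

# Dict-structured dataset: the single batched element is a dict of tensors.
ds_dict = tf.data.Dataset.from_tensor_slices({"x": features, "y": labels})
all_dict = ds_dict.batch(4).get_single_element()

print(x_all.shape, y_all.shape)
print(all_dict["x"].shape, all_dict["y"].shape)
```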

See this answer on Stack Overflow for an example that unpacks features and labels into a tuple of tensors.

answered Oct 05 '22 at 10:10 by manisar



In TensorFlow 2.x

You can iterate over the dataset with as_numpy_iterator(), which converts each element to NumPy:

for element in Xtrain.as_numpy_iterator():  # Xtrain is a tf.data.Dataset
    print(element)  # each element arrives as NumPy array(s)
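If you need all the elements together rather than one at a time, one option (a sketch, not part of the original answer) is to collect the iterator's output into a list and stack it with NumPy. The toy Xtrain below is an assumption for illustration. np.stack requires every element to have the same shape, and this still loops over the elements in Python, so it can be slow for large datasets.

```python
import numpy as np
import tensorflow as tf

# Toy dataset standing in for Xtrain (assumption for illustration):
# 3 elements, each a vector of length 2.
Xtrain = tf.data.Dataset.from_tensor_slices(np.arange(6).reshape(3, 2))

# Materialise every element as NumPy, then stack them into one array
# whose first axis indexes the elements.
elements = list(Xtrain.as_numpy_iterator())
X_all = np.stack(elements)
print(X_all.shape)  # (3, 2)
```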
answered Oct 05 '22 at 10:10 by Abhishek S
