Before TensorFlow 2.0-beta, to retrieve the first element from a tf.data.Dataset, we could use an iterator as shown below:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # 1.0 will be printed.
    print(sess.run(iterator.get_next()))
In TensorFlow 2.0-beta, the one-shot iterator above appears to be deprecated. To print all of the elements, we can use a for loop instead:
#!/usr/bin/python
import tensorflow as tf
train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
for data in train_dataset:
    # 1.0, 2.0, 3.0, and 4.0 will be printed.
    print(data.numpy())
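If the goal is to gather every element rather than print it, the same loop can collect the values into a Python list. This is a minimal sketch, assuming eager execution (the TF 2.x default) and a finite dataset; the helper name `dataset_to_list` is hypothetical, not part of TensorFlow.

```python
import tensorflow as tf


def dataset_to_list(dataset):
    """Collect every element of a finite dataset as NumPy values.

    Hypothetical helper: it relies on eager execution, where each
    element is an EagerTensor that has a .numpy() method.
    """
    return [elem.numpy() for elem in dataset]


train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
values = dataset_to_list(train_dataset)
print(values)
```

Avoid calling this on an infinite dataset (for example one built with `.repeat()` and no count), since the loop would never end.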
However, if we only want to retrieve exactly one element from a tf.data.Dataset, how can we do that in TensorFlow 2.0-beta? Calling next(train_dataset) does not seem to be supported. This was easy with the old one-shot iterator shown above, but it is not obvious with the new for-loop approach.
Any suggestions are welcome.
What works with eager execution enabled (the default in TF 2.0) is:
elem = next(iter(train_dataset))
Explanation: datasets have an __iter__ member function to support the `for elem in dataset:` approach, and it returns an iterator. The Python built-in iter does just that: it basically calls the __iter__ function. next then returns the first element that iterator produces.
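To make the iter/next behaviour concrete, here is a small sketch, assuming eager execution: calling next repeatedly on the same iterator walks through the dataset, while a fresh call to iter starts again from the first element.

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# iter() calls train_dataset.__iter__() and returns a fresh iterator.
iterator = iter(train_dataset)

# Each call to next() advances that same iterator by one element.
first = next(iterator)
second = next(iterator)

# A new call to iter() starts again from the beginning of the dataset.
first_again = next(iter(train_dataset))

print(first.numpy(), second.numpy(), first_again.numpy())
```

Like any Python iterator, it raises StopIteration once the dataset is exhausted; `next(iterator, None)` returns None instead.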
I haven't found a solution that works without eager execution, though: there, this currently raises RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
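If a TF 1.x-style graph workflow is acceptable, one possible workaround is sketched below. It assumes TF 2.x, where the old API is still available under `tf.compat.v1`. The pipeline is built inside an explicit `tf.Graph`, so its ops are created in graph (non-eager) mode, and the first element is fetched with the compat one-shot iterator and a session. Treat it as an untested sketch rather than a recommendation; the error message above also suggests wrapping the code in `tf.function`.

```python
import tensorflow as tf

# Inside graph.as_default(), ops are built in graph mode even though
# TF 2.x enables eager execution globally by default.
graph = tf.Graph()
with graph.as_default():
    train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
    # TF 1.x-style one-shot iterator, kept under tf.compat.v1 in TF 2.x.
    iterator = tf.compat.v1.data.make_one_shot_iterator(train_dataset)
    next_element = iterator.get_next()

    with tf.compat.v1.Session(graph=graph) as sess:
        # Each sess.run() on next_element pulls the next value (1.0 first).
        first = sess.run(next_element)

print(first)
```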
You can .take(1) from the dataset:
for elem in train_dataset.take(1):
    print(elem.numpy())
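To get the element as a value instead of printing it inside the loop, the take(1) loop can be wrapped in a small function. This is a sketch assuming eager execution; `first_element` is a hypothetical helper name. Unlike `next(iter(dataset))`, it returns None for an empty dataset instead of raising StopIteration.

```python
import tensorflow as tf


def first_element(dataset):
    """Return the first element of `dataset`, or None if it is empty.

    Hypothetical helper built on Dataset.take(1); assumes eager execution.
    """
    for elem in dataset.take(1):
        return elem
    return None


train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
print(first_element(train_dataset).numpy())

# A dataset with no elements: the loop body never runs, so None is returned.
empty = train_dataset.filter(lambda x: x > 10.0)
print(first_element(empty))
```

TensorFlow also provides `tf.data.experimental.get_single_element`, which can be applied to `dataset.take(1)`. It raises an error unless the dataset has exactly one element, so check the documentation for your version before relying on it.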