I would like to pass a dataset iterator into a function, but that function needs to know the length of the dataset. In the example below I could pass len(datafiles) into the my_custom_fn() function, but I'm wondering whether I can extract the length of the dataset from either iterator, batch_x, or batch_y so that I do not have to add it as an extra input.
import tensorflow as tf

# datafiles: list of record file paths; record_bytes: bytes per record
dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)  # sess is an existing tf.Session
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff
Thanks!
Edit: This solution does not work in my case: tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?
After running
tf.data.Dataset.list_files('{}/*.dat')
the suggested expression
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])
returns
<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>
I did find a solution that works for me: add a name scope around the iterator code, like:
with tf.name_scope('iter'):
    dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)
    [batch_x, batch_y] = iterator.get_next()
    value = my_custom_fn(batch_x, batch_y)
    # lots of other stuff
Then, from inside my_custom_fn, I call:
def my_custom_fn(batch_x, batch_y):
    filenames = batch_x.graph.get_operation_by_name(
        'iter/InputDataSet/filenames').outputs[0]
    n_epoch = sess.run(sess.graph.get_operation_by_name(
        'iter/Iterator/count').outputs)[0]
    batch_size = sess.run(sess.graph.get_operation_by_name(
        'iter/Iterator/batch_size').outputs)[0]
    # lots of other stuff
Not sure if this is the best way, but it works for me. Happy to take any advice on this, as it feels a little hacky.
The length of an iterator is unknown until you iterate through it. You could explicitly pass len(datafiles) into the function, but if you want the length to live alongside the data, you could make my_custom_fn an instance method and store the dataset's length on the object that owns it.
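For instance, a minimal sketch of that idea might look like the following. The DatasetWrapper name is made up, and it follows the question's assumption that len(datafiles) is the length you care about:

import tensorflow as tf

class DatasetWrapper(object):
    """Hypothetical helper that keeps the dataset and its length together."""

    def __init__(self, datafiles, record_bytes):
        # Store the length once, so methods can use it without extra arguments.
        self.dataset_len = len(datafiles)
        self.dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)

    def my_custom_fn(self, batch_x, batch_y):
        # self.dataset_len is available here without being passed in.
        print('dataset length:', self.dataset_len)
        # lots of other stuff

You would then build the iterator from wrapper.dataset and call wrapper.my_custom_fn(batch_x, batch_y) instead of the free function.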
Unfortunately, an iterator does not store anything itself; it generates the data dynamically. However, as can be seen in the TensorFlow source code, there is a 'private' _batch_size attribute that stores the batch size. You can review the source code here: TensorFlow source.
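A rough sketch of reading it in TF 1.x might look like this. The file names and sizes below are made up, and _batch_size is a private attribute, so this can break between TensorFlow versions:

import tensorflow as tf

# Made-up inputs, purely for illustration.
datafiles = ['data_0.dat', 'data_1.dat']
record_bytes = 128

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
batched = dataset.batch(32)

with tf.Session() as sess:
    # In TF 1.x the batch size is kept as an int64 tensor on the BatchDataset.
    print(sess.run(batched._batch_size))  # expected: 32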