 

Determine number of records in tf.data.Dataset Tensorflow [duplicate]

I would like to pass a dataset iterator into a function, but that function needs to know the length of the dataset. In the example below, I could pass len(datafiles) into my_custom_fn() function, but I'm wondering if I am able to extract the length of the data set from either iterator, batch_x, or batch_y classes so that I do not have to add it as an input.

dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer)
[batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

Thanks!

Edit: This solution does not work in my case: tf.data.Dataset: how to get the dataset size (number of elements in a epoch)?

After running

tf.data.Dataset.list_files('{}/*.dat')
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0')[0])

returns

<tf.Tensor 'Shape_3:0' shape=(0,) dtype=int32>

I did find a solution that works for me: wrap the dataset and iterator in a name scope, like:

with tf.name_scope('iter'):
    dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)
    iterator = dataset.make_initializable_iterator()
    sess.run(iterator.initializer)
    [batch_x, batch_y] = iterator.get_next()
value = my_custom_fn(batch_x, batch_y)
# lots of other stuff

Then from inside my_custom_fn call:

def my_custom_fn(batch_x, batch_y):
    # Note: `sess` is assumed to be the tf.Session created earlier, and the
    # operation names below depend on the graph built under the 'iter' scope.
    filenames = batch_x.graph.get_operation_by_name(
        'iter/InputDataSet/filenames').outputs[0]
    n_epoch = sess.run(sess.graph.get_operation_by_name(
        'iter/Iterator/count').outputs)[0]
    batch_size = sess.run(sess.graph.get_operation_by_name(
        'iter/Iterator/batch_size').outputs)[0]
    # lots of other stuff

Not sure if this is the best way, but it seems to work. Happy to take any advice on this, as it seem a little hacky.
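A less graph-dependent alternative, since the records have a fixed length, is to compute the count directly from the file sizes before building the dataset. This is only a sketch and assumes every file contains whole records with no header or footer bytes:

```python
import os

def count_fixed_length_records(datafiles, record_bytes):
    """Total number of records across files of fixed-size records.

    Assumes each file holds only whole records (no header/footer bytes).
    """
    return sum(os.path.getsize(f) // record_bytes for f in datafiles)
```

The resulting integer can then be passed to `my_custom_fn` alongside the batches from `iterator.get_next()`, with no graph introspection needed.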

John asked Sep 10 '18 19:09


1 Answer

The length of an iterator is unknown until you have iterated through it. You could explicitly pass len(datafiles) into the function, but if you want the length to travel with the data, you could make my_custom_fn an instance method and store the dataset's length on the object it belongs to.
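A minimal sketch of that instance-method idea (the `DatasetWrapper` class and its attribute names are hypothetical, not part of tf.data):

```python
class DatasetWrapper:
    """Hypothetical holder that pairs a dataset with its known length."""

    def __init__(self, datafiles, record_bytes):
        self.datafiles = list(datafiles)
        self.record_bytes = record_bytes
        self.num_files = len(self.datafiles)  # known up front, stored as state
        # In real code you would build the tf.data pipeline here, e.g.:
        # self.dataset = tf.data.FixedLengthRecordDataset(datafiles, record_bytes)

    def my_custom_fn(self, batch_x, batch_y):
        # The length is available as instance state; no extra argument needed.
        print('number of data files:', self.num_files)
        return batch_x, batch_y


wrapper = DatasetWrapper(['a.dat', 'b.dat', 'c.dat'], record_bytes=128)
```

Any method on the wrapper can then read `self.num_files` instead of receiving the length as a parameter.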

Unfortunately, an iterator does not store anything; it generates the data dynamically. However, as can be seen in the TensorFlow source code, batched datasets keep a 'private' attribute _batch_size which stores the batch size. You can review the source code here: TensorFlow source.

Philip DiSarro answered Nov 20 '22 03:11