
How to iterate a dataset several times using TensorFlow's Dataset API?

How do I iterate over the values in a dataset several times? (The dataset is created with TensorFlow's Dataset API.)

import tensorflow as tf

dataset = tf.contrib.data.Dataset.range(100)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()
epoch = 10

for i in range(epoch):
    for j in range(100):
        value = sess.run(next_element)  # raises tf.errors.OutOfRangeError at the start of the second epoch
        assert j == value
        print(j)

Error message:

tensorflow.python.framework.errors_impl.OutOfRangeError: End of sequence
 [[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[]], output_types=[DT_INT64], _device="/job:localhost/replica:0/task:0/cpu:0"](OneShotIterator)]]

How can I make this work?

asked Nov 02 '17 by void


2 Answers

First of all, I advise you to read the Dataset guide; it describes all the details of the Dataset API.

Your question is about iterating over the data several times. Here are two solutions for that:

  1. Iterating over all epochs at once, with no information about where individual epochs end:
import tensorflow as tf

epoch   = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0
j = 0
while True:
    try:
        value = sess.run(next_element)
        assert j == value
        j += 1
        num_batch += 1
        if j > 99: # new epoch
            j = 0
    except tf.errors.OutOfRangeError:
        break

print ("Num Batch: ", num_batch)
  2. The second option tells you when each epoch ends, so you can, for example, check the validation loss (a sketch of such a check follows the code below):
import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
sess = tf.Session()

num_batch = 0

for e in range(epoch):
    print ("Epoch: ", e)
    j = 0
    sess.run(iterator.initializer)
    while True:
        try:
            value = sess.run(next_element)
            assert j == value
            j += 1
            num_batch += 1
        except tf.errors.OutOfRangeError:
            break

print ("Num Batch: ", num_batch)
answered Nov 02 '22 by melgor89


If your TensorFlow version is 1.3+, I recommend the high-level API tf.train.MonitoredTrainingSession. A session created by this API automatically detects the end of input (tf.errors.OutOfRangeError), which you can check with sess.should_stop(). For most training situations you need to shuffle the data and take a batch each step; I have added both to the code below. A short sketch with a toy training op follows it.

import tensorflow as tf

epoch = 10
dataset = tf.data.Dataset.range(100)
dataset = dataset.shuffle(buffer_size=100) # comment this line if you don't want to shuffle data
dataset = dataset.batch(batch_size=32)     # batch_size=1 if you want to get only one element per step
dataset = dataset.repeat(epoch)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

num_batch = 0
with tf.train.MonitoredTrainingSession() as sess:
    while not sess.should_stop():
        value = sess.run(next_element)
        num_batch += 1
        print("Num Batch: ", num_batch)
answered Nov 02 '22 by Tom