 

Can't use estimator + dataset and train for less than one epoch

TensorFlow 1.4 moves the Dataset API to core (tf.data.Dataset), and the docs/tutorials suggest using tf.estimator to train models.

However, as recommended at the end of this page, the Dataset object and its iterator must be instantiated inside the input_fn function. This means that iteration through the dataset starts over at each call to estimator.train(input_fn, steps). Thus, calling it with fewer steps than one epoch contains will train the model on only a subset of the dataset.
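To make the restart concrete, here is a plain-Python simulation (not TensorFlow code). The hypothetical make_input_fn returns a function that builds a fresh iterator every time it is called, mimicking how estimator.train calls input_fn anew on each invocation; train_steps is a hypothetical stand-in for estimator.train(input_fn, steps) that only records which samples it consumed. The second part models what the question asks for: one iterator that persists across calls.

```python
import itertools

def make_input_fn(data):
    """Hypothetical stand-in for an input_fn: each call builds a NEW iterator."""
    def input_fn():
        return iter(data)
    return input_fn

def train_steps(input_fn, steps):
    """Hypothetical stand-in for estimator.train(input_fn, steps):
    calls input_fn once, then consumes `steps` samples from it."""
    it = input_fn()
    return list(itertools.islice(it, steps))

data = list(range(10))
input_fn = make_input_fn(data)

# Three "training" calls of 4 steps each: every call restarts at sample 0.
seen_fresh = [train_steps(input_fn, 4) for _ in range(3)]
print(seen_fresh)  # [[0, 1, 2, 3], [0, 1, 2, 3], [0, 1, 2, 3]]

# The desired behavior: a single iterator that keeps its position across calls.
persistent = itertools.cycle(data)
seen_persistent = [list(itertools.islice(persistent, 4)) for _ in range(3)]
print(seen_persistent)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 0, 1]]
```

With the fresh-iterator pattern, samples 4 to 9 are never reached no matter how many times train_steps is called.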

Hence my question: is it possible to implement something like this with Estimator + Dataset:

for i in range(num_epochs):
    # Train for some steps
    estimator.train(input_fn=train_input_fn, steps=valid_freq)

    # Evaluate on the validation set (steps=None, we evaluate on the full validation set)
    estimator.evaluate(input_fn=valid_input_fn)

without restarting the iteration over the training samples from scratch at each call to estimator.train(input_fn=train_input_fn, steps=valid_freq)?

For example, could the Dataset and its iterator be instantiated outside input_fn, unlike here? I tried it, but it does not work: the input (from the dataset iterator) and the model (from the estimator model_fn) then end up in different graphs.

Thanks

Related GitHub issue

asked Nov 07 '17 09:11 by nisace



1 Answer

I don't know of any way to make training resume where the previous call to estimator.train() stopped in the dataset.

However, what you can do is build train_input_fn so that it is random enough to obtain the same effect.


For instance, suppose you have a dataset of values [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] and you can only train on half of the dataset at each call of estimator.train.
If you don't shuffle well enough, you will keep training on the values [0, 1, 2, 3, 4]:

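import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

train_size = 10
dataset = tf.data.Dataset.range(train_size)
x = dataset.make_one_shot_iterator().get_next()

sess = tf.Session()
# No shuffling: reading half the dataset always yields 0, 1, 2, 3, 4
for i in range(train_size // 2):
    print(sess.run(x))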

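However, if you call tf.data.Dataset.shuffle() with a buffer_size at least as large as the dataset, you will get the values in a random order, so each half-dataset call sees a different random subset. Calling estimator.train multiple times this way is then approximately equivalent to calling it once for multiple epochs: every sample is seen equally often on average, though not exactly once per epoch.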

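import tensorflow as tf  # TensorFlow 1.x API (tf.Session, make_one_shot_iterator)

train_size = 10
dataset = tf.data.Dataset.range(train_size)
# A buffer as large as the dataset gives a full random permutation
dataset = dataset.shuffle(buffer_size=train_size)
x = dataset.make_one_shot_iterator().get_next()

sess = tf.Session()
# Reading half the dataset now yields a random subset of 0..9
for i in range(train_size // 2):
    print(sess.run(x))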
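The claim that repeated short estimator.train calls behave like multi-epoch training can be checked with a small plain-Python simulation (again not TensorFlow code; shuffled_input_fn and train_steps are hypothetical names). Each call reshuffles the whole dataset, standing in for shuffle(buffer_size=train_size) inside input_fn, and consumes only half of it.

```python
import random
from collections import Counter

def shuffled_input_fn(data, rng):
    """Hypothetical input_fn: rebuilds a fully shuffled pass over `data` on
    every call, like Dataset.shuffle(buffer_size=len(data)) inside input_fn."""
    order = list(data)
    rng.shuffle(order)
    return iter(order)

def train_steps(input_fn, steps):
    """Hypothetical stand-in for estimator.train(input_fn, steps)."""
    it = input_fn()
    return [next(it) for _ in range(steps)]

train_size = 10
data = list(range(train_size))
rng = random.Random(0)

counts = Counter()
num_calls = 2000
for _ in range(num_calls):
    # Each call only sees half of the dataset, as in the example above.
    seen = train_steps(lambda: shuffled_input_fn(data, rng), train_size // 2)
    counts.update(seen)

# Every sample is visited, each roughly num_calls / 2 times.
print(sorted(counts.items()))
```

Over many calls the visit counts even out, which is the sense in which this matches training for several epochs; unlike true epochs, two consecutive calls can overlap or both miss a given sample.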

I wrote another answer to explain the importance of buffer_size here.
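The role of buffer_size can also be modeled in plain Python. This sketch assumes the behavior described in the tf.data.Dataset.shuffle documentation (fill a buffer with buffer_size elements, emit a randomly chosen one, replace it with the next input element); buffered_shuffle and first_half_coverage are hypothetical names, not TensorFlow APIs. It measures how often the last sample (9) lands in the first half of the output, i.e. how often a half-dataset estimator.train call would ever see it.

```python
import random

def buffered_shuffle(stream, buffer_size, rng):
    """Plain-Python model of a shuffle buffer: fill the buffer, then repeatedly
    emit a random element from it and refill that slot from the stream."""
    it = iter(stream)
    buffer = []
    for x in it:
        buffer.append(x)
        if len(buffer) == buffer_size:
            break
    for x in it:
        i = rng.randrange(len(buffer))
        yield buffer[i]
        buffer[i] = x
    # Stream exhausted: drain what is left in the buffer in random order.
    rng.shuffle(buffer)
    yield from buffer

def first_half_coverage(buffer_size, trials=1000, seed=0):
    """Fraction of trials in which sample 9 appears in the first 5 outputs."""
    rng = random.Random(seed)
    data = list(range(10))
    hits = 0
    for _ in range(trials):
        first_half = list(buffered_shuffle(data, buffer_size, rng))[:5]
        hits += 9 in first_half
    return hits / trials

print(first_half_coverage(2))   # 0.0: with a small buffer, 9 can never come out early
print(first_half_coverage(10))  # close to 0.5: a full buffer gives a uniform shuffle
```

With buffer_size=2, the k-th output can only come from the first 2 + k inputs, so the late samples are never reached within half an epoch; only a buffer covering the whole dataset removes that bias.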

answered Sep 21 '22 22:09 by Olivier Moindrot
