Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

How to use Tensorflow dataset API with training and validation sets

Simple task at hand: run training for N epochs, calculating the exact validation accuracy after each epoch. The epoch size can be either the full training set or a predefined number of iterations. During validation, every validation-set input has to be evaluated exactly once.

What would be the best way to combine one-shot iterators, initializable iterators and/or a feedable string handle for this task?

Here is a scaffold of how I think it should work:

def build_training_dataset():
    """Placeholder: build and return the training ``tf.data.Dataset``.

    Not implemented in this scaffold; always returns ``None``.
    """
    return None

def build_validation_dataset():
    """Placeholder: build and return the validation ``tf.data.Dataset``.

    Not implemented in this scaffold; always returns ``None``.
    """
    return None

def construct_train_op(dataset):
    """Placeholder: build the training op that consumes ``dataset``.

    Not implemented in this scaffold; ignores its argument and returns
    ``None``.
    """
    return None

def magic(iterator, validation_dataset=None):
    """Placeholder: combine training and validation input into one source.

    The scaffold calls this as ``magic(training_dataset, validation_dataset)``
    with two positional arguments. The original one-parameter signature would
    raise ``TypeError`` at that call site, so an optional second parameter is
    accepted. Existing one-argument calls keep working.

    Args:
        iterator: the training input (the name is kept from the original
            signature for compatibility; the scaffold passes a dataset here).
        validation_dataset: the validation input; optional.

    Returns:
        ``None`` -- this scaffold does not implement the combination.
    """
    return None

USE_CUSTOM_EPOCH_SIZE = True
CUSTOM_EPOCH_SIZE = 60
MAX_EPOCHS = 100


training_dataset = build_training_dataset()
validation_dataset = build_validation_dataset()


# Magic goes here to build a nice one-instance dataset
dataset = magic(training_dataset, validation_dataset)

train_op = construct_train_op(dataset)

# Run N epochs in which the training dataset is traversed, followed by the
# validation dataset.
with tf.Session() as sess:
    # MAX_EPOCHS is an int, so iterate over range() rather than the int itself.
    for epoch in range(MAX_EPOCHS):

        # train
        if USE_CUSTOM_EPOCH_SIZE:
            for _ in range(CUSTOM_EPOCH_SIZE):
                sess.run(train_op)
        else:
            while True:
                try:
                    sess.run(train_op)
                except tf.errors.OutOfRangeError:
                    break  # we are done with the epoch

        # validation
        validation_predictions = []
        while True:
            try:
                pred = sess.run(train_op)  # but for validation this time
            except tf.errors.OutOfRangeError:
                break  # the validation input is exhausted
            # Keep the result (np.append returns a new array and does not
            # modify its argument). np.ravel mirrors np.append's flattening.
            validation_predictions.extend(np.ravel(pred))
        print('epoch %d finished with accuracy: %f'
              % (epoch, np.mean(validation_predictions)))
like image 887
y.selivonchyk Avatar asked Nov 17 '17 18:11

y.selivonchyk


People also ask

How does TensorFlow use validation data?

TensorFlow Data Validation identifies any anomalies in the input data by comparing data statistics against a schema. The schema codifies properties which the input data is expected to satisfy, such as data types or categorical values, and can be modified or replaced by the user.


1 Answer

Since the solution is a lot messier than I expected, it comes in 2 pieces:

0) Auxiliary code shared by both examples:

# Dataset sizes and batching shared by both examples.
TRAIN_SIZE = 500
VALIDATION_SIZE = 145
BATCH_SIZE = 64

# Epoch configuration. Example 1 trains until the training iterator is
# exhausted; example 2 runs CUSTOM_EPOCH_SIZE training steps per epoch.
# USE_CUSTOM_EPOCH_SIZE is not read by either example in this answer.
USE_CUSTOM_EPOCH_SIZE = True
CUSTOM_EPOCH_SIZE = 60
MAX_EPOCHS = 100


def construct_train_op(batch):
    """Stand-in for a real training op: hand the input batch straight back.

    Running the returned value in a session therefore yields the batch
    itself, which lets the examples count the elements that were consumed.
    """
    train_op = batch
    return train_op


def build_train_dataset():
    """Return the batched training dataset used by the examples.

    Elements are the integers ``0 .. TRAIN_SIZE - 1``. Each one is shifted by
    a random int64 offset from ``tf.random_uniform([], -10, 10, tf.int64)``,
    and the results are grouped into batches of ``BATCH_SIZE``.
    """
    def _jitter(value):
        # Independent random shift for each element.
        return value + tf.random_uniform([], -10, 10, tf.int64)

    indices = tf.data.Dataset.range(TRAIN_SIZE)
    jittered = indices.map(_jitter)
    return jittered.batch(BATCH_SIZE)

def build_test_dataset():
    """Return the batched validation dataset used by the examples.

    Elements are the integers ``0 .. VALIDATION_SIZE - 1`` in order, with no
    random transformation, grouped into batches of ``BATCH_SIZE``.
    """
    ordered = tf.data.Dataset.range(VALIDATION_SIZE)
    return ordered.batch(BATCH_SIZE)

1) For an epoch equal to the full training dataset size:

# datasets construction
training_dataset = build_train_dataset()
validation_dataset = build_test_dataset()

# Handle construction. A feedable iterator created from a string handle lets
# the single `next_element` tensor read from either dataset; which one is
# chosen per `sess.run` call by feeding `handle` in `feed_dict`.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()

train_op = construct_train_op(next_element)

# Initializable iterators are rewound by running their `initializer`, so each
# epoch can traverse the whole training set and the whole validation set.
training_iterator = training_dataset.make_initializable_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

with tf.Session() as sess:
    # Evaluate each iterator's string handle once; these are the values fed
    # into `handle` below.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    for epoch in range(MAX_EPOCHS):
        # train: rewind the training iterator and consume it until exhausted.
        sess.run(training_iterator.initializer)
        total_in_train = 0
        while True:
            try:
                train_output = sess.run(train_op, feed_dict={handle: training_handle})
            except tf.errors.OutOfRangeError:
                break  # we are done with the epoch
            total_in_train += len(train_output)
        # Sanity check: the epoch covered exactly TRAIN_SIZE examples.
        assert total_in_train == TRAIN_SIZE

        # validation: rewind the validation iterator and evaluate every batch.
        validation_predictions = []
        sess.run(validation_iterator.initializer)
        while True:
            try:
                pred = sess.run(train_op, feed_dict={handle: validation_handle})
            except tf.errors.OutOfRangeError:
                break  # the validation set has been fully consumed
            # list.extend is amortised O(1) per element, whereas np.append
            # copies the entire accumulated array on every batch.
            # np.ravel keeps np.append's flattening behaviour.
            validation_predictions.extend(np.ravel(pred))
        assert len(validation_predictions) == VALIDATION_SIZE
        print('Epoch %d finished with accuracy: %f' % (epoch, np.mean(validation_predictions)))

2) For custom epoch size:

# datasets construction
# CHANGE 1: repeat() with no count makes the training dataset endless, so the
# training side never raises OutOfRangeError.
training_dataset = build_train_dataset().repeat()
validation_dataset = build_test_dataset()

# Handle construction. A feedable iterator created from a string handle lets
# the single `next_element` tensor read from either dataset; which one is
# chosen per `sess.run` call by feeding `handle` in `feed_dict`.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_dataset.output_types, training_dataset.output_shapes)
next_element = iterator.get_next()


train_op = construct_train_op(next_element)

# CHANGE 2: the endless training dataset is never rewound, so a one-shot
# iterator suffices. Validation still needs an initializable iterator so it
# can restart from the first batch every epoch.
training_iterator = training_dataset.make_one_shot_iterator()
validation_iterator = validation_dataset.make_initializable_iterator()

with tf.Session() as sess:
    # Evaluate each iterator's string handle once; these are the values fed
    # into `handle` below.
    training_handle = sess.run(training_iterator.string_handle())
    validation_handle = sess.run(validation_iterator.string_handle())

    for epoch in range(MAX_EPOCHS):
        # train
        # CHANGE 3: no initialization and no try/except: the training stream
        # is endless, so an "epoch" is simply CUSTOM_EPOCH_SIZE steps.
        for _ in range(CUSTOM_EPOCH_SIZE):
            sess.run(train_op, feed_dict={handle: training_handle})

        # validation: rewind the validation iterator and evaluate every batch.
        validation_predictions = []
        sess.run(validation_iterator.initializer)
        while True:
            try:
                pred = sess.run(train_op, feed_dict={handle: validation_handle})
            except tf.errors.OutOfRangeError:
                break  # the validation set has been fully consumed
            # list.extend is amortised O(1) per element, whereas np.append
            # copies the entire accumulated array on every batch.
            # np.ravel keeps np.append's flattening behaviour.
            validation_predictions.extend(np.ravel(pred))
        assert len(validation_predictions) == VALIDATION_SIZE
        print('Epoch %d finished with accuracy: %f' % (epoch, np.mean(validation_predictions)))
like image 79
y.selivonchyk Avatar answered Nov 02 '22 04:11

y.selivonchyk