Epoch counter with TensorFlow Dataset API

I'm changing my TensorFlow code from the old queue interface to the new Dataset API. In my old code I kept track of the epoch count by incrementing a tf.Variable every time a new input tensor is accessed and processed in the queue. I'd like to have this epoch count with the new Dataset API, but I'm having some trouble making it work.

Since the pre-processing stage produces a variable number of data items, I can't simply increment a (Python) counter in the training loop: I need to compute the epoch count with respect to the input of the queues or Dataset.
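A minimal pure-Python sketch (no TensorFlow) of why the count has to be tied to the source rather than to the training loop. The names `augment` and `tagged_stream` are hypothetical, and `augment` stands in for a pre-processing stage that emits a variable number of items per source element. Progress advances by 1/n per *source* element, like the +0.1 per element for the 10-element source in the question, so the fractional epoch attached to each output item stays correct however many items each source element produces.

```python
def augment(x):
    # Hypothetical pre-processing step: emits a variable number of items per
    # source element (x % 3 + 1 copies), so outputs != source elements.
    return [x] * (x % 3 + 1)


def tagged_stream(source, num_epochs):
    """Yield (item, fractional_epoch) pairs, counting progress on the source."""
    n = len(source)
    for epoch in range(num_epochs):
        for idx, x in enumerate(source):
            # Progress is measured in source elements, not in output items.
            progress = epoch + idx / n
            for item in augment(x):
                yield item, progress


for item, progress in tagged_stream([0, 1, 2, 3], num_epochs=2):
    print(item, progress)
```

With this 4-element source, one epoch produces 7 output items, so counting items (or training steps) would suggest 7 per epoch rather than 4.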

I mimicked what I had before with the old queue system, and here is what I ended up with for the Dataset API (simplified example):

with tf.Graph().as_default():

    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    epoch_counter = tf.Variable(initial_value=0.0, dtype=tf.float32,
                                trainable=False)

    def pre_processing_func(data_):
        data_size = tf.constant(0.1, dtype=tf.float32)
        epoch_counter_op = tf.assign_add(epoch_counter, data_size)
        with tf.control_dependencies([epoch_counter_op]):
            # normally I would do data-augmentation here
            results = (tf.expand_dims(data_, axis=0),)
            return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    dataset = dataset.repeat()
    # ... do something with 'dataset' and print
    # the value of 'epoch_counter' every once a while

However, this doesn't work. It crashes with a cryptic error message:

 TypeError: In op 'AssignAdd', input types ([tf.float32, tf.float32])
 are not compatible with expected types ([tf.float32_ref, tf.float32])

Closer inspection shows that the epoch_counter variable might not be accessible within the pre_processing_func at all. Does it live in a different graph perhaps?

Any idea how to fix the above example? Or how to get the epoch counter (with decimal points, e.g. 0.4 or 2.9) through some other means?

asked Nov 21 '17 at 10:11 by CNugteren


2 Answers

TL;DR: Replace the definition of epoch_counter with the following:

epoch_counter = tf.get_variable("epoch_counter", initializer=0.0,
                                trainable=False, use_resource=True)

There are some limitations around using TensorFlow variables inside tf.data.Dataset transformations. The principal limitation is that all variables must be "resource variables" and not the older "reference variables"; unfortunately tf.Variable still creates "reference variables" for backwards compatibility reasons.

Generally speaking, I wouldn't recommend using variables in a tf.data pipeline if it's possible to avoid it. For example, you might be able to use Dataset.range() to define an epoch counter, and then do something like:

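# NUM_EPOCHS, pre_processing_func and data are placeholders for your own code;
# pre_processing_func(data) should return the Dataset for one epoch.
epoch_counter = tf.data.Dataset.range(NUM_EPOCHS)  # yields 0, 1, ..., NUM_EPOCHS - 1 (tf.int64)
dataset = epoch_counter.flat_map(lambda i: tf.data.Dataset.zip(
    (pre_processing_func(data), tf.data.Dataset.from_tensors(i).repeat())))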

The above snippet attaches an epoch counter to every value as a second component.
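The same data flow can be sketched in plain Python. This is an analogue, not TensorFlow code, and `pre_processing` and `with_epoch` are hypothetical names: `range` plays the role of `Dataset.range`, `itertools.chain.from_iterable` the role of `flat_map`, and `zip` with `itertools.repeat(i)` the role of `Dataset.zip` with `Dataset.from_tensors(i).repeat()`.

```python
import itertools


def pre_processing(data):
    # Hypothetical stand-in for pre_processing_func: one output per input here,
    # but any iterable of any length works.
    return (x * 10 for x in data)


def with_epoch(data, num_epochs):
    # zip() stops at its shorter input, as Dataset.zip does, so the infinite
    # itertools.repeat(i) is cut off at the length of one epoch.
    return itertools.chain.from_iterable(
        zip(pre_processing(data), itertools.repeat(i)) for i in range(num_epochs)
    )


print(list(with_epoch([1, 2, 3], num_epochs=2)))
# [(10, 0), (20, 0), (30, 0), (10, 1), (20, 1), (30, 1)]
```

Because the per-epoch data is regenerated for every `i`, the epoch index never depends on how many items the pre-processing emits.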

answered Oct 26 '22 at 00:10 by mrry


To add to @mrry's great answer: if you want to stay within the tf.data pipeline and also track the iteration within each epoch, you can try my solution below. If your batch size is greater than one, you would presumably also need to add the line data = data.batch(bs).

import tensorflow as tf
import itertools

def step_counter():
    # Infinite generator 0, 1, 2, ...; called anew each time the dataset is re-iterated.
    yield from itertools.count()

num_examples = 3
num_epochs = 2
num_iters = num_examples * num_epochs

features = tf.data.Dataset.range(num_examples)
labels = tf.data.Dataset.range(num_examples)
data = tf.data.Dataset.zip((features, labels))
data = data.shuffle(num_examples)

step = tf.data.Dataset.from_generator(step_counter, tf.int32)
data = tf.data.Dataset.zip((data, step))

epoch = tf.data.Dataset.range(num_epochs)
data = epoch.flat_map(
    lambda i: tf.data.Dataset.zip(
        (data, tf.data.Dataset.from_tensors(i).repeat())))

it = data.make_one_shot_iterator()
example = it.get_next()

with tf.Session() as sess:
    for _ in range(num_iters):
        ((x, y), st), ep = sess.run(example)
        print(f'step {st} \t epoch {ep} \t x {x} \t y {y}')

Prints:

step 0   epoch 0     x 2     y 2
step 1   epoch 0     x 0     y 0
step 2   epoch 0     x 1     y 1
step 0   epoch 1     x 2     y 2
step 1   epoch 1     x 0     y 0
step 2   epoch 1     x 1     y 1
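A plain-Python analogue of this pipeline can show how the (step, epoch) pair becomes the fractional epoch asked about in the question (e.g. 0.4 or 2.9), and what batching does to the tags. The helper names `tagged_examples`, `fractional_epoch` and `batched` are hypothetical, and shuffling is omitted to keep the sketch deterministic. After batching, each batch holds several (step, epoch) pairs, so the step still counts examples rather than batches.

```python
from itertools import islice


def tagged_examples(examples, num_epochs):
    """Yield (example, step, epoch) triples, like the tf.data pipeline above."""
    for epoch in range(num_epochs):
        # enumerate() restarts at 0 every epoch, matching the step restart
        # visible in the printed output above.
        for step, example in enumerate(examples):
            yield example, step, epoch


def fractional_epoch(step, epoch, num_examples):
    # e.g. step 2 of 4 in epoch 1 -> 1.5
    return epoch + step / num_examples


def batched(iterable, batch_size):
    # Rough analogue of data.batch(bs): group consecutive triples into lists.
    it = iter(iterable)
    while True:
        batch = list(islice(it, batch_size))
        if not batch:
            return
        yield batch


examples = [(0, 0), (1, 1), (2, 2), (3, 3)]
for batch in batched(tagged_examples(examples, num_epochs=2), 3):
    print([(step, epoch, fractional_epoch(step, epoch, len(examples)))
           for _, step, epoch in batch])
```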
answered Oct 25 '22 at 23:10 by numerica