What's the best way to take a single pass over a dataset, in order to evaluate on test data? I'd like to avoid scripting the data loading in Python and using a feed_dict. Instead, I'd like to use all the nice TF infrastructure for queueing, batching, etc.
In the CIFAR example, the number of test examples is hard-coded and the code takes num_test_examples/batch_size steps in order to do evaluation. It seems like there should be a better way to do this, using the batching infrastructure.
It seems that the standard pattern is to stop running when you catch some exception thrown by the queue. I've tried a few things, but they only make the queue complain when there are no more examples left to populate it (i.e. the producer can't produce any more). That isn't the exception you want to catch: you want to catch the point when the consumer has nothing left to consume, i.e. the queue is empty. How do I do this?
Also, what do you do if the number of test examples isn't divisible by the batch size (e.g. the number of test examples is prime)?
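To make this concrete, here's a minimal runnable sketch of that standard pattern under TF 1.x queue runners (the zeros tensor stands in for real test data): tf.train.limit_epochs makes the producer raise tf.errors.OutOfRangeError after one pass, and allow_smaller_final_batch=True makes tf.train.batch emit a short final batch when the test-set size isn't divisible by the batch size.

import tensorflow as tf

# 100 fake test examples of width 2; limit_epochs stops the producer after one epoch
examples = tf.train.limit_epochs(tf.zeros([100, 2]), num_epochs=1)
batch = tf.train.batch([examples], batch_size=3, enqueue_many=True,
                       allow_smaller_final_batch=True)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())  # limit_epochs keeps a local counter
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        while True:
            print(sess.run(batch).shape)  # (3, 2) batches, then a final (1, 2)
    except tf.errors.OutOfRangeError:
        pass  # queue closed and empty: the single pass is complete
    finally:
        coord.request_stop()
        coord.join(threads)

As the additional information below explains, though, this doesn't help when you want to evaluate repeatedly during training, because there's no clean way to reset limit_epochs for another pass.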
Additional information:
In practice, we typically evaluate on test data multiple times during learning, by calling a do_evaluation() function. Yaroslav's answer is useful if you only want to process the test data once. Ideally, each call to do_evaluation would run over every example in the test dataset exactly once, so we need some mechanism for resetting the batcher to take a single pass over it one more time. Here's some code for that (don't use the limit_epochs command). The function takes a batcher that doesn't shuffle, plus the number of batches in the test set (so this doesn't work if the number of examples isn't divisible by the minibatch size). It returns a new op for grabbing data that will throw a tf.errors.OutOfRangeError when you've run over the whole set. The second return value is an op that should be called to reset the counter; that should be the first call inside a do_evaluation() function.
def single_pass(source_batcher, num_batches):
    zero = tf.constant(0, dtype=tf.int64)
    batch_count = tf.Variable(zero, name="epochs", trainable=False)
    # count_up_to raises tf.errors.OutOfRangeError once batch_count reaches num_batches
    limiter = tf.count_up_to(batch_count, num_batches)
    with tf.control_dependencies([limiter]):
        batcher = tf.identity(source_batcher)
    # Running this op rewinds the counter so another full pass can be taken
    reset = tf.assign(batch_count, zero)
    return batcher, reset
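For example, a hypothetical do_evaluation() built on top of single_pass might look like the sketch below (the metric accumulation is left as a placeholder, and batch_count is assumed to be initialized along with the rest of the variables):

def do_evaluation(sess, batcher, reset):
    sess.run(reset)  # first call: rewind the counter for a fresh pass
    while True:
        try:
            batch = sess.run(batcher)
            # ... accumulate evaluation metrics over `batch` here ...
        except tf.errors.OutOfRangeError:
            break  # num_batches batches consumed: the pass is complete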
You can use the tf.data API for this, like so:
import tensorflow as tf

graph = tf.Graph()
sess = tf.Session(graph=graph)

def build_dataset(train_or_test):
    if train_or_test == 'test':
        dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([4, 2]))
    elif train_or_test == 'train':
        dataset = tf.data.Dataset.from_tensor_slices(tf.ones([10, 2]))
    else:
        raise ValueError('wrong name')
    batch_size = 3
    dataset = dataset.batch(batch_size)
    return dataset

def build_inputs():
    train_dataset = build_dataset('train')
    test_dataset = build_dataset('test')
    # A reinitializable iterator can be pointed at either dataset
    iterator = tf.data.Iterator.from_structure(
        train_dataset.output_types,
        train_dataset.output_shapes)
    data = iterator.get_next()
    tf.identity(data, name='data')
    iterator.make_initializer(train_dataset, name='train_init')
    iterator.make_initializer(test_dataset, name='test_init')

def model(inputs):
    return tf.add(inputs, 1, name='output')

def build_graph():
    with graph.as_default():
        build_inputs()
        data = graph.get_tensor_by_name('data:0')
        model(data)

def train():
    # Re-running the initializer resets the iterator to the start of the dataset
    train_init = graph.get_operation_by_name('train_init')
    sess.run(train_init)
    out = graph.get_tensor_by_name('output:0')
    while True:
        try:
            network_out = sess.run(out)
            print(network_out.shape)
            print(network_out)
        except tf.errors.OutOfRangeError:
            # The iterator is exhausted: one full pass is complete
            break

def test():
    test_init = graph.get_operation_by_name('test_init')
    sess.run(test_init)
    out = graph.get_tensor_by_name('output:0')
    while True:
        try:
            network_out = sess.run(out)
            print(network_out.shape)
            print(network_out)
        except tf.errors.OutOfRangeError:
            break

def train_loop():
    cur_epoch = 0
    while cur_epoch < 1:
        print('Test epoch')
        test()
        print('Train epoch')
        train()
        cur_epoch += 1

def initialise_graph():
    with graph.as_default():
        sess.run(tf.global_variables_initializer())

build_graph()
initialise_graph()
train_loop()
This will output:
Test epoch
(3, 2)
[[1. 1.]
[1. 1.]
[1. 1.]]
(1, 2)
[[1. 1.]]
Train epoch
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(3, 2)
[[2. 2.]
[2. 2.]
[2. 2.]]
(1, 2)
[[2. 2.]]
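Note the final (1, 2) batches in the output above: Dataset.batch simply emits a smaller last batch when the dataset size isn't divisible by the batch size, which also answers the prime-sized-test-set concern from the question. In newer TF versions you can pass drop_remainder=True to Dataset.batch if you'd rather discard the partial batch.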