 

Restoring a Tensorflow model that uses Iterators

I have a model that trains my network using an Iterator, following the new Dataset API pipeline that is now recommended by Google.

I read tfrecord files, feed data to the network, and train nicely; all is going well. I save my model at the end of training so I can run inference on it later. A simplified version of the code is as follows:

""" Training and saving """

training_dataset = tf.contrib.data.TFRecordDataset(training_record)
training_dataset = training_dataset.map(ds._path_records_parser)
training_dataset = training_dataset.batch(BATCH_SIZE)
with tf.name_scope("iterators"):
  training_iterator = Iterator.from_structure(training_dataset.output_types, training_dataset.output_shapes)
  next_training_element = training_iterator.get_next()
  training_init_op = training_iterator.make_initializer(training_dataset)

def train(num_epochs):
  # compute for the number of epochs
  for e in range(1, num_epochs+1):
    session.run(training_init_op) #initializing iterator here
    while True:
      try:
        images, labels = session.run(next_training_element)
        session.run(optimizer, feed_dict={x: images, y_true: labels})
      except tf.errors.OutOfRangeError:
        saver_name = './saved_models/ucf-model'
        saver.save(session, saver_name) # saver is defined elsewhere in the full code
        print("Finished Training Epoch {}".format(e))
        break



    """ Restoring """
# restoring the saved model and its variables
session = tf.Session()
saver = tf.train.import_meta_graph(r'saved_models\ucf-model.meta')
saver.restore(session, tf.train.latest_checkpoint('.\saved_models'))
graph = tf.get_default_graph()

# restoring relevant tensors/ops
accuracy = graph.get_tensor_by_name("accuracy/Mean:0") #the tensor that when evaluated returns the mean accuracy of the batch
testing_iterator = graph.get_operation_by_name("iterators/Iterator") #my iterator used in testing.
next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext") #the GetNext operator for my iterator
# loading my testing set tfrecords
testing_dataset = tf.contrib.data.TFRecordDataset(testing_record_path)
testing_dataset = testing_dataset.map(ds._path_records_parser, num_threads=4, output_buffer_size=BATCH_SIZE*20)
testing_dataset = testing_dataset.batch(BATCH_SIZE)

testing_init_op = testing_iterator.make_initializer(testing_dataset) #to initialize the dataset

with tf.Session() as session:
  session.run(testing_init_op)
  while True:
    try:
      images, labels = session.run(next_testing_element)
      batch_accuracy = session.run(accuracy, feed_dict={x: images, y_true: labels}) #error here, x, y_true not defined
    except tf.errors.OutOfRangeError:
      break

My problem arises mainly when I restore the model: how do I feed testing data to the network?

  • When I restore my Iterator using testing_iterator = graph.get_operation_by_name("iterators/Iterator"), next_testing_element = graph.get_operation_by_name("iterators/IteratorGetNext"), I get the following error: GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
  • So I tried to initialize my dataset using testing_init_op = testing_iterator.make_initializer(testing_dataset). I got this error: AttributeError: 'Operation' object has no attribute 'make_initializer'

Another issue: since an iterator is used, there is no need for placeholders in the training model, as the iterator feeds data directly to the graph. But then how do I restore my feed_dict keys in the third-to-last line of the restore snippet, where I feed data to the "accuracy" op?
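For reference, if the placeholders had been created with explicit names when the graph was built, they could be recovered from the restored graph by name. A minimal sketch of that, with hypothetical names x and y_true:

# hypothetical: assumes the placeholders were created with
# tf.placeholder(..., name='x') and tf.placeholder(..., name='y_true')
x = graph.get_tensor_by_name("x:0")
y_true = graph.get_tensor_by_name("y_true:0")

But no such placeholders exist in my graph, since the Iterator feeds the network directly.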

EDIT: If someone could suggest a way to add placeholders between the Iterator and the network input, then I could try running the graph by evaluating the "accuracy" tensor while feeding data to the placeholders and ignoring the Iterator altogether. A rough sketch of what I have in mind follows.
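Something like this minimal sketch, using tf.placeholder_with_default so that each placeholder falls back to the Iterator output when nothing is fed (the shapes and names here are hypothetical):

# hypothetical shapes and names: each placeholder defaults to the Iterator
# output, so the graph can be driven either by the Dataset pipeline
# or by an explicit feed_dict at inference time
next_images, next_labels = training_iterator.get_next()
x = tf.placeholder_with_default(next_images, shape=[None, 224, 224, 3], name='x')
y_true = tf.placeholder_with_default(next_labels, shape=[None, NUM_CLASSES], name='y_true')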

asked Oct 24 '17 by Omar Hommos


2 Answers

When restoring a saved meta graph, you can restore the initialization operation by name and then use it again to initialize the input pipeline for inference. (This also explains the AttributeError you saw: graph.get_operation_by_name returns a plain tf.Operation, not an Iterator, so it has no make_initializer method; restoring the named initializer op sidesteps this.)

That is, when creating the graph, you can do

    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

And then restore this operation by doing:

    dataset_init_op = graph.get_operation_by_name('dataset_init')

Here is a self-contained code snippet that compares the results of a randomly initialized model before and after restoring.

Saving an Iterator

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the operation
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')     
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        # save once the dataset is exhausted, then exit the loop
        saver.save(sess, 'tmp/', global_step=1002)
        break

And then you can restore the same model for inference as follows:

Restoring saved iterator

import os

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')

X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        break
answered Nov 02 '22 by metastableB


I would suggest using tf.contrib.data.make_saveable_from_iterator, which has been designed precisely for this purpose. It is much less verbose and does not require you to change existing code, in particular how you define your iterator.

Here is a working example, where we save everything after step 5 has completed. Note that I don't even bother to know what seed is used.

import tensorflow as tf

iterator = (
  tf.data.Dataset.range(100)
  .shuffle(10)
  .make_one_shot_iterator())
batch = iterator.get_next(name='batch')

saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
saver = tf.train.Saver()

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  for step in range(10):
    print('{}: {}'.format(step, sess.run(batch)))
    if step == 5:
      saver.save(sess, './foo', global_step=step)

# 0: 1
# 1: 6
# 2: 7
# 3: 3
# 4: 8
# 5: 10
# 6: 12
# 7: 14
# 8: 5
# 9: 17

Then later, if we resume from step 6, we get the same output: the saveable object checkpoints the iterator's internal state (including the shuffle buffer), so the restored pipeline continues exactly where it left off.

import tensorflow as tf

saver = tf.train.import_meta_graph('./foo-5.meta')
with tf.Session() as sess:
  saver.restore(sess, './foo-5')
  for step in range(6, 10):
    print('{}: {}'.format(step, sess.run('batch:0')))
# 6: 12
# 7: 14
# 8: 5
# 9: 17
answered Nov 02 '22 by P-Gn