
How to use tf.train.MonitoredTrainingSession to restore only certain variables

Tags:

tensorflow

How does one tell a tf.train.MonitoredTrainingSession to restore only a subset of the variables, and perform initialization on the rest?

Starting with the cifar10 tutorial (https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py), I created lists of the variables to restore and to initialize, and specified them using a Scaffold that I pass to the MonitoredTrainingSession:

  restoration_saver = tf.train.Saver(var_list=restore_vars)
  restoration_scaffold = tf.train.Scaffold(init_op=tf.variables_initializer(init_vars),
                                           ready_op=tf.constant([]),
                                           saver=restoration_saver)

but this gives the following error:

RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]

.. where the uninitialized variables listed in the error message are the variables in my "init_vars" list.

The exception is raised by SessionManager.prepare_session(). The source code for that method seems to indicate that if the session is restored from a checkpoint, then the init_op is not run. So it looks like you can either have restored variables or initialized variables, but not both.
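In pseudocode, the guard in prepare_session() behaves roughly like this (a paraphrase of the control flow, not the actual TensorFlow source):

```python
# Simplified paraphrase of the relevant logic in
# SessionManager.prepare_session(): the init op only runs when NO
# checkpoint was restored, so you get either restored variables or
# initialized variables, never both.
def prepare_session_flow(loaded_from_checkpoint, have_init_op):
    ran_init_op = False
    if not loaded_from_checkpoint:  # the guard that skips init on restore
        if have_init_op:
            ran_init_op = True
    return ran_init_op
```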

asked Feb 05 '23 by user550701

2 Answers

OK, so as I suspected, I got what I wanted by implementing a new RefinementSessionManager class based on the existing tf.train.SessionManager. The two classes are almost identical, except that I modified the prepare_session method to call the init_op regardless of whether the model was loaded from a checkpoint.

This allows me to load a list of variables from the checkpoint and initialize the remaining variables in the init_op.

My prepare_session method is this:

  def prepare_session(self, master, init_op=None, saver=None,
                      checkpoint_dir=None, wait_for_checkpoint=False,
                      max_wait_secs=7200, config=None, init_feed_dict=None,
                      init_fn=None):

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(
        master,
        saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=wait_for_checkpoint,
        max_wait_secs=max_wait_secs,
        config=config)

    # [removed] if not is_loaded_from_checkpoint:
    # We still want to run any supplied initialization on models that
    # were loaded from checkpoint.

    if not is_loaded_from_checkpoint and init_op is None and not init_fn and self._local_init_op is None:
      raise RuntimeError("Model is not initialized and no init_op or "
                         "init_fn or local_init_op was given")
    if init_op is not None:
      sess.run(init_op, feed_dict=init_feed_dict)
    if init_fn:
      init_fn(sess)

    # [...]

Hope this helps somebody else.

answered Feb 08 '23 by user550701


The hint from @avital works. To be more complete: pass a Scaffold object into MonitoredTrainingSession with a local_init_op and a ready_for_local_init_op, like so:

  model_ready_for_local_init_op = tf.report_uninitialized_variables(
      var_list=var_list)
  model_init_tmp_vars = tf.variables_initializer(var_list)
  scaffold = tf.train.Scaffold(saver=model_saver,
                               local_init_op=model_init_tmp_vars,
                               ready_for_local_init_op=model_ready_for_local_init_op)
  with tf.train.MonitoredTrainingSession(...,
                                         scaffold=scaffold,
                                         ...) as mon_sess:
    ...
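Conceptually, the startup handshake this Scaffold sets up looks like the following pure-Python sketch (plain Python with placeholder names, not the TensorFlow API):

```python
# Hypothetical sketch of the startup sequence MonitoredTrainingSession
# runs with this Scaffold:
#   1. the saver restores the checkpointed variables,
#   2. ready_for_local_init_op checks that those variables are set,
#   3. local_init_op initializes the remaining (temporary) variables.
def startup(checkpoint_vars, tmp_vars):
    initialized = set()
    initialized.update(checkpoint_vars)        # step 1: restore
    not_ready = [v for v in checkpoint_vars    # step 2: readiness check
                 if v not in initialized]
    if not_ready:
        raise RuntimeError("Variables not initialized: %s" % not_ready)
    initialized.update(tmp_vars)               # step 3: local init
    return initialized
```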
answered Feb 08 '23 by Bastiaan