How does one tell a tf.train.MonitoredTrainingSession to restore only a subset of the variables, and perform initialization on the rest?
Starting with the cifar10 tutorial (https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10/cifar10_train.py), I created lists of the variables to restore and to initialize, and specified them using a Scaffold that I pass to the MonitoredTrainingSession:
    restoration_saver = tf.train.Saver(var_list=restore_vars)
    restoration_scaffold = tf.train.Scaffold(init_op=tf.variables_initializer(init_vars),
                                             ready_op=tf.constant([]),
                                             saver=restoration_saver)
but this gives the following error:
    RuntimeError: Init operations did not make model ready for local_init. Init op: group_deps, init fn: None, error: Variables not initialized: conv2a/T, conv2b/T, [...]
where the uninitialized variables listed in the error message are exactly the variables in my "init_vars" list.
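(For reference, a minimal sketch of how the two lists might be built; the conv2a/conv2b scopes are taken from the error message above and stand in for whatever variables are new to the graph:)

    import tensorflow as tf

    all_vars = tf.global_variables()
    # Assumed split: the variables named in the error are the new ones that
    # must be initialized; everything else is restored from the checkpoint.
    init_vars = [v for v in all_vars
                 if v.op.name.startswith(("conv2a/", "conv2b/"))]
    restore_vars = [v for v in all_vars if v not in init_vars]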
The exception is raised by SessionManager.prepare_session(). The source code for that method seems to indicate that if the session is restored from a checkpoint, then the init_op is not run. So it looks like you can have either restored variables or initialized variables, but not both.
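Roughly, the gating logic in the stock prepare_session() looks like this (a paraphrase, not the verbatim TensorFlow source):

    sess, is_loaded_from_checkpoint = self._restore_checkpoint(...)
    if not is_loaded_from_checkpoint:
        # init_op and init_fn only run when nothing was restored
        if init_op is not None:
            sess.run(init_op, feed_dict=init_feed_dict)
        if init_fn:
            init_fn(sess)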
OK, so as I suspected, I got what I wanted by implementing a new RefinementSessionManager class based on the existing tf.train.SessionManager. The two classes are almost identical, except that I modified the prepare_session method to call the init_op regardless of whether the model was loaded from a checkpoint.
This allows me to load a list of variables from the checkpoint and initialize the remaining variables in the init_op.
My prepare_session method is this:
    def prepare_session(self, master, init_op=None, saver=None,
                        checkpoint_dir=None, wait_for_checkpoint=False,
                        max_wait_secs=7200, config=None, init_feed_dict=None,
                        init_fn=None):
      sess, is_loaded_from_checkpoint = self._restore_checkpoint(
          master,
          saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=wait_for_checkpoint,
          max_wait_secs=max_wait_secs,
          config=config)

      # [removed] if not is_loaded_from_checkpoint:
      # We still want to run any supplied initialization on models that
      # were loaded from a checkpoint.
      if (not is_loaded_from_checkpoint and init_op is None
          and not init_fn and self._local_init_op is None):
        raise RuntimeError("Model is not initialized and no init_op or "
                           "init_fn or local_init_op was given")

      if init_op is not None:
        sess.run(init_op, feed_dict=init_feed_dict)
      if init_fn:
        init_fn(sess)

      # [...]
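A hedged usage sketch (assuming RefinementSessionManager keeps tf.train.SessionManager's constructor, and that checkpoint_dir is a hypothetical variable pointing at the existing checkpoint):

    sm = RefinementSessionManager(graph=tf.get_default_graph())
    sess = sm.prepare_session(
        master='',
        init_op=tf.variables_initializer(init_vars),   # now runs even after a restore
        saver=tf.train.Saver(var_list=restore_vars),   # restores only restore_vars
        checkpoint_dir=checkpoint_dir)                 # hypothetical path variable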
Hope this helps somebody else.
The hint from @avital works; to be more complete, pass a Scaffold object into MonitoredTrainingSession with a local_init_op covering the variables to initialize and a ready_for_local_init_op covering the variables being restored. Like so:
    model_saver = tf.train.Saver(var_list=restore_vars)  # restores only these
    model_ready_for_local_init_op = tf.report_uninitialized_variables(
        var_list=restore_vars)  # "ready" once the restored variables are set
    model_init_tmp_vars = tf.variables_initializer(init_vars)
    scaffold = tf.train.Scaffold(saver=model_saver,
                                 local_init_op=model_init_tmp_vars,
                                 ready_for_local_init_op=model_ready_for_local_init_op)
    with tf.train.MonitoredTrainingSession(...,
                                           scaffold=scaffold,
                                           ...) as mon_sess:
        ...
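For what it's worth, my understanding of the order of operations inside SessionManager is: the scaffold's saver restores from the checkpoint first, then ready_for_local_init_op must report no uninitialized variables (which is why it should watch only the restored set), then local_init_op runs, and finally the ready_op verifies that the whole model is initialized.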