Is it possible to restore a variable from a checkpoint only if it is present there? What is the most idiomatic way of doing so?
For instance, consider the following minimal example:
```python
import tensorflow as tf
import glob
import sys
import os

with tf.variable_scope('volatile'):
    x = tf.get_variable('x', initializer=0)

with tf.variable_scope('persistent'):
    y = tf.get_variable('y', initializer=0)
    add1 = tf.assign_add(y, 1)

saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, 'persistent'))

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
tf.get_default_graph().finalize()

print('save file', sys.argv[1])
if glob.glob(sys.argv[1] + '*'):
    saver.restore(sess, sys.argv[1])
print(sess.run(y))
sess.run(add1)
print(sess.run(y))
saver.save(sess, sys.argv[1])
```
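A side note on the `saver` line above: `tf.get_collection(key, scope)` treats `scope` as a regular expression and keeps only the items whose `name` matches it from the start (via `re.match`), which is why `volatile/x` never ends up in the checkpoint. The plain-Python sketch below mimics that filter so it can be tried without TensorFlow; `filter_by_scope` and the `persistent_old` scope are hypothetical names for illustration, not TensorFlow API.

```python
import re


def filter_by_scope(names, scope):
    """Mimic the name filter applied by tf.get_collection(key, scope).

    `scope` is compiled as a regular expression and matched at the start of
    each name (re.match), so 'persistent' behaves like a name prefix.
    """
    if scope is None:
        return list(names)  # no scope: the whole collection is returned
    regex = re.compile(scope)
    return [name for name in names if regex.match(name)]


# Hypothetical variable names; 'persistent_old' is an invented sibling scope.
names = ['volatile/x:0', 'persistent/y:0', 'persistent_old/w:0']
print(filter_by_scope(names, 'persistent'))
# ['persistent/y:0', 'persistent_old/w:0']
print(filter_by_scope(names, 'persistent/'))
# ['persistent/y:0']
```

Ending the scope with a slash, as in `'persistent/'`, keeps sibling scopes such as `persistent_old` from matching by accident.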
When run twice with the same argument, the program prints `0` and then `1` for `y` on the first run, and `1` and then `2` on the second, as expected. Now suppose you extend the code with new functionality by adding `z = tf.get_variable('z', initializer=0)` after `add1` within the `persistent` scope. Running the program again while the old save file is present fails with the following error:
```
NotFoundError (see above for traceback): Key persistent/z not found in checkpoint
    [[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_INT32],
    _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0,
    save/RestoreV2_1/tensor_names,
    save/RestoreV2_1/shape_and_slices)]]
    [[Node: save/Assign_1/_18 = _Recv[client_terminated=false,
    recv_device="/job:localhost/replica:0/task:0/device:GPU:0",
    send_device="/job:localhost/replica:0/task:0/device:CPU:0",
    send_device_incarnation=1,
    tensor_name="edge_12_save/Assign_1",
    tensor_type=DT_FLOAT,
    _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
```
You can use the following function to restore only the variables that are present in the checkpoint (taken from here):
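```python
import tensorflow as tf


def optimistic_restore(session, save_file, graph=None):
    """Restore the variables in `save_file` whose names and shapes match the graph."""
    # Resolve the graph at call time: a `graph=tf.get_default_graph()` default
    # would be evaluated only once, when the function is defined.
    if graph is None:
        graph = tf.get_default_graph()
    reader = tf.train.NewCheckpointReader(save_file)
    # Map from checkpoint key (e.g. 'persistent/y') to the saved shape.
    saved_shapes = reader.get_variable_to_shape_map()
    restore_vars = []
    with graph.as_default():
        for var in sorted(tf.global_variables(), key=lambda v: v.name):
            saved_var_name = var.name.split(':')[0]  # drop the ':0' output suffix
            # Skip variables absent from the checkpoint or whose shape has changed.
            if (saved_var_name in saved_shapes and
                    var.get_shape().as_list() == saved_shapes[saved_var_name]):
                restore_vars.append(var)
        if not restore_vars:
            return  # tf.train.Saver raises ValueError on an empty var_list
        opt_saver = tf.train.Saver(restore_vars)
    opt_saver.restore(session, save_file)
```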
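I usually run `sess.run(tf.global_variables_initializer())` first to make sure all variables are initialized, and then call `optimistic_restore(sess, ...)` to restore the variables that can be restored. Variables missing from the checkpoint, such as the new `z`, simply keep their initial values. Because `optimistic_restore` creates a new `tf.train.Saver`, which adds ops to the graph, call it before `tf.get_default_graph().finalize()`.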
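The selection rule inside `optimistic_restore` can be checked without TensorFlow. The sketch below uses a hypothetical `partition_for_restore` helper that works on plain dicts mapping names to shape lists, standing in for `tf.global_variables()` and `get_variable_to_shape_map()`. It shows which variables would be restored and which keep their initializer values after adding `z` to the question's code.

```python
def partition_for_restore(graph_vars, saved_shapes):
    """Split graph variables into (restored, missing, mismatched) lists.

    graph_vars:   graph variable name (e.g. 'persistent/y:0') -> shape list
    saved_shapes: checkpoint key (e.g. 'persistent/y') -> shape list
    """
    restored, missing, mismatched = [], [], []
    for var_name in sorted(graph_vars):
        saved_name = var_name.split(':')[0]  # checkpoint keys drop the ':0' suffix
        if saved_name not in saved_shapes:
            missing.append(var_name)      # keeps its initializer value
        elif graph_vars[var_name] != saved_shapes[saved_name]:
            mismatched.append(var_name)   # shape changed: also keeps its initializer value
        else:
            restored.append(var_name)     # would be handed to tf.train.Saver
    return restored, missing, mismatched


# Scalars have shape [], as get_shape().as_list() reports for tf.get_variable(..., initializer=0).
graph_vars = {'volatile/x:0': [], 'persistent/y:0': [], 'persistent/z:0': []}
saved_shapes = {'persistent/y': []}  # the old checkpoint only holds y
print(partition_for_restore(graph_vars, saved_shapes))
# (['persistent/y:0'], ['persistent/z:0', 'volatile/x:0'], [])
```

Keeping shape-mismatched variables out of the restore list is what lets the function survive a variable being resized between runs, not just a variable being added.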