
Restore subset of variables in Tensorflow

I am training a Generative Adversarial Network (GAN) in TensorFlow, where basically there are two different networks, each one with its own optimizer.

self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
                      .minimize(self.d_loss, var_list=self.d_vars)

The problem is that I train one of the networks (g) first, and then I want to train g and d together. However, when I call the load function:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        return True
    else:
        return False

I get an error like this (with a lot more traceback):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

I can restore the g network and keep training with that function, but when I want to start d from scratch and g from the stored model, I get that error.

asked Jan 12 '17 by Roger Trullo


3 Answers

To restore a subset of variables, you must create a new tf.train.Saver and pass it a specific list of variables to restore in the optional var_list argument.

By default, a tf.train.Saver will create ops that (i) save every variable in your graph when you call saver.save() and (ii) look up (by name) every variable in the given checkpoint when you call saver.restore(). While this works for most common scenarios, you have to provide more information to work with specific subsets of the variables:

  1. If you only want to restore a subset of the variables, you can get a list of these variables by calling tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX), assuming that you put the "g" network in a common with tf.name_scope(G_NETWORK_PREFIX): or tf.variable_scope(G_NETWORK_PREFIX): block. You can then pass this list to the tf.train.Saver constructor.

  2. If you want to restore a subset of the variables and/or the variables in the checkpoint have different names, you can pass a dictionary as the var_list argument. By default, each variable in a checkpoint is associated with a key, which is the value of its tf.Variable.name property. If the name is different in the target graph (e.g. because you added a scope prefix), you can specify a dictionary that maps string keys (in the checkpoint file) to tf.Variable objects (in the target graph), as in the sketch after this list.
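
For concreteness, here is a minimal sketch of both options. The scope name and the renaming rule are placeholders (my assumptions, not taken from the question), a tf.Session named sess is assumed to exist, and the checkpoint path is the one from the error message above:

G_NETWORK_PREFIX = 'generator'  # placeholder: the scope the "g" variables were created under

# Option 1: restore only the variables in the generator scope.
g_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)
g_saver = tf.train.Saver(var_list=g_vars)

# Option 2: a dictionary mapping checkpoint names (keys) to graph variables (values),
# e.g. when the new graph adds a scope prefix that the checkpoint does not have.
name_map = {v.op.name.replace(G_NETWORK_PREFIX + '/', ''): v for v in g_vars}
g_saver_renamed = tf.train.Saver(var_list=name_map)

sess.run(tf.global_variables_initializer())            # initialize everything, including "d"
g_saver.restore(sess, 'checkpoint/MR2CT.model-96000')  # then overwrite "g" from the checkpoint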

answered by mrry


I had a similar problem when restoring only part of my variables from a checkpoint, where some of the saved variables did not exist in the new model. Inspired by @Lidong's answer, I modified the reading function a little:

import tensorflow as tf
from tensorflow.python import pywrap_tensorflow

def get_tensors_in_checkpoint_file(file_name, all_tensors=True, tensor_name=None):
    # Return the variable names and values stored in a checkpoint file.
    varlist = []
    var_value = []
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in sorted(var_to_shape_map):
            varlist.append(key)
            var_value.append(reader.get_tensor(key))
    else:
        varlist.append(tensor_name)
        var_value.append(reader.get_tensor(tensor_name))
    return (varlist, var_value)

and added a loading function:

def build_tensors_in_checkpoint_file(loaded_tensors):
    # Map the names loaded from the checkpoint to the matching tensors in the current graph.
    full_var_list = list()
    # Loop over all loaded tensor names
    for i, tensor_name in enumerate(loaded_tensors[0]):
        # Look up the corresponding tensor in the default graph
        try:
            tensor_aux = tf.get_default_graph().get_tensor_by_name(tensor_name + ":0")
        except KeyError:
            # The saved variable does not exist in the new model, so skip it
            print('Not found: ' + tensor_name)
            continue
        full_var_list.append(tensor_aux)
    return full_var_list

Then you can simply load all common variables using:

CHECKPOINT_NAME = 'path/to/the/saved/checkpoint'  # placeholder: path to the checkpoint file
restored_vars  = get_tensors_in_checkpoint_file(file_name=CHECKPOINT_NAME)
tensors_to_load = build_tensors_in_checkpoint_file(restored_vars)
loader = tf.train.Saver(tensors_to_load)
loader.restore(sess, CHECKPOINT_NAME)

Edit: I am using tensorflow 1.2
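
One thing worth adding (not from the original answer): the variables that exist only in the new model, e.g. the discriminator and its optimizer slots, are not restored and still need to be initialized. A possible sketch, reusing tensors_to_load from the snippet above:

# Initialize only the variables that were NOT restored, so the values
# loaded from the checkpoint are left untouched.
all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
restored_names = set(t.name for t in tensors_to_load)
not_restored = [v for v in all_vars if v.name not in restored_names]
sess.run(tf.variables_initializer(not_restored))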

answered by Ramiro R.C.


Inspired by @mrry, I propose a solution for this problem. To make it clear, I formulate the problem as restoring a subset of the variables from the checkpoint when the model is built on top of a pre-trained model. First, we can use the print_tensors_in_checkpoint_file function from the inspect_checkpoint library, or simply extract this function:

from tensorflow.python import pywrap_tensorflow
def print_tensors_in_checkpoint_file(file_name, tensor_name, all_tensors):
    varlist = []
    reader = pywrap_tensorflow.NewCheckpointReader(file_name)
    if all_tensors:
        var_to_shape_map = reader.get_variable_to_shape_map()
        for key in sorted(var_to_shape_map):
            varlist.append(key)
    return varlist

# Replace the placeholder below with the path of the ckpt file.
varlist = print_tensors_in_checkpoint_file(file_name='path/to/the/ckpt/file', all_tensors=True, tensor_name=None)

Then we use tf.get_collection() just like @mrry said:

variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

Finally, we can initialize the saver by:

saver = tf.train.Saver(variables[:len(varlist)])

The complete version can be found at my github: https://github.com/pobingwanghai/tensorflow_trick/blob/master/restore_from_checkpoint.py

In my situation, the new variables are added at the end of the model, so I can simply use the slice [:len(varlist)] to identify the needed variables. For a more complex situation, you might have to do some hand-alignment work or write a simple string-matching function to determine the required variables (see the sketch below).
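
For example, a minimal sketch of such a matching function (an illustration, not part of the original code): it keeps only the graph variables whose names also appear in the checkpoint, regardless of their order.

def match_variables_to_checkpoint(variables, checkpoint_names):
    # Keep the graph variables whose name (without the ':0' suffix)
    # also exists in the checkpoint, independent of ordering.
    checkpoint_names = set(checkpoint_names)
    return [v for v in variables if v.op.name in checkpoint_names]

saver = tf.train.Saver(match_variables_to_checkpoint(variables, varlist))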

answered by Lidong