Loading two models from Saver in the same TensorFlow session

I have two networks: a Model which generates output and an Adversary which grades the output.

Both have been trained separately but now I need to combine their outputs during a single session.

I've attempted to implement the solution proposed in this post: Run multiple pre-trained Tensorflow nets at the same time

My code:

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip the superfluous ":0" suffix, which for some reason is not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("adv/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

The problem

The call to model_saver.restore() appears to do nothing. In another module I use a saver created with tf.train.Saver(tf.global_variables()) and it restores the checkpoint fine.

The model exposes model.tvars = tf.trainable_variables(). To check what was happening I used sess.run() to extract the tvars before and after the restore. Each time the initial randomly assigned values are still present, and the values from the checkpoint are never assigned.
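To make that check concrete, a minimal sketch (using the sess, model_saver and model_ckpt from the code above; not the poster's exact code):

import numpy as np

before = sess.run(model.tvars)   # values after random initialisation
model_saver.restore(sess, model_ckpt.model_checkpoint_path)
after = sess.run(model.tvars)    # values after the restore

# If the restore worked, at least some variables should have changed
changed = [not np.array_equal(b, a) for b, a in zip(before, after)]
print("variables changed by restore:", sum(changed), "/", len(changed))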

Any thoughts on why model_saver.restore() appears to be doing nothing?

asked Jan 12 '17 by TheCriticalImperitive


2 Answers

Solving this problem took a long time, so I'm posting my likely imperfect solution in case anyone else needs it.

To diagnose the problem I manually looped through the variables and assigned them one by one. I then noticed that a variable's name would change after it was assigned. This is described here: TensorFlow checkpoint save and read
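A rough sketch of what that manual loop might look like (an assumption, using tf.train.NewCheckpointReader together with the sess and model_ckpt from the question):

# Read the tensors stored in the checkpoint and their names
reader = tf.train.NewCheckpointReader(model_ckpt.model_checkpoint_path)
ckpt_names = set(reader.get_variable_to_shape_map().keys())

for v in tf.global_variables():
    name = v.op.name[len("model/"):]  # drop the scope prefix added at build time
    if name in ckpt_names:
        # Each assign adds new ops to the graph, which is where the
        # changing names showed up
        sess.run(v.assign(reader.get_tensor(name)))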

Based on the advice in that post I ran each model in its own graph. That also meant running each graph in its own session, so the session management had to be handled differently.

First I created two graphs:

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

Then two sessions:

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

Then I initialised the variables in each session and restored each graph separately:

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

From here on, whenever one of the sessions was needed, I would wrap any tf calls for it in a with sess.as_default(): block. At the end I close the sessions manually:

sess.close()
adv_sess.close()
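From there, combining the two networks means fetching from one session and feeding into the other, for example (model.input, model.output, adversary.input, adversary.score and batch are hypothetical names, not from the original code):

with sess.as_default():
    generated = sess.run(model.output, feed_dict={model.input: batch})

with adv_sess.as_default():
    score = adv_sess.run(adversary.score, feed_dict={adversary.input: generated})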
answered by TheCriticalImperitive


The accepted answer does not show how to load two different models into one session explicitly, so here is my approach:

  1. Create two different name scopes for the models you want to load.

  2. Initialize two savers that will load the parameters for the variables in the two different networks.

  3. Restore from the corresponding checkpoint files.

with tf.Session() as sess:
    with tf.name_scope("net1"):
        net1 = Net1()
    with tf.name_scope("net2"):
        net2 = Net2()

    # Strip the "net1/" scope prefix by slicing (str.lstrip strips a
    # character set, not a prefix, so it can mangle variable names)
    net1_varlist = {v.op.name[len("net1/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)

    net2_varlist = {v.op.name[len("net2/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)

    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")
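If a restore fails with a "not found in checkpoint" error, it can help to compare the remapped keys against the names actually stored in the checkpoint. A small diagnostic sketch using tf.train.NewCheckpointReader (not part of the original answer):

reader = tf.train.NewCheckpointReader("net1.ckpt")
ckpt_names = set(reader.get_variable_to_shape_map().keys())
missing = [name for name in net1_varlist if name not in ckpt_names]
print("keys missing from checkpoint:", missing)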
answered by shinxg