I have two networks: a Model, which generates output, and an Adversary, which grades the output.
Both have been trained separately, but now I need to combine their outputs during a single session.
I've attempted to implement the solution proposed in this post: Run multiple pre-trained Tensorflow nets at the same time
My code
with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)
#...
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    # Get the variables specific to the `Model`
    # Also strip out the superfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)
    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)
The problem
The call to model_saver.restore() appears to be doing nothing. In another module I use a saver created with tf.train.Saver(tf.global_variables()), and it restores the checkpoint fine.
The model has model.tvars = tf.trainable_variables(). To check what was happening, I used sess.run() to extract the tvars before and after the restore. Each time, the initial randomly assigned values are used and the values from the checkpoint are never assigned.
Any thoughts on why model_saver.restore() appears to be doing nothing?
Solving this problem took a long time, so I'm posting my likely imperfect solution in case anyone else needs it.
To diagnose the problem, I manually looped through the variables and assigned them one by one. I then noticed that after a variable was assigned, its name changed. This is described here: TensorFlow checkpoint save and read
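A quicker way to spot this kind of renaming than assigning variables one at a time is to compare the names stored in the checkpoint with the names of the variables in the graph. The sketch below is plain Python with made-up names, and diff_names is a hypothetical helper. In later TensorFlow 1.x releases the checkpoint side could come from tf.train.list_variables(path), which returns (name, shape) pairs, and the graph side from [v.op.name for v in tf.global_variables()].

```python
def diff_names(checkpoint_names, graph_names):
    """Return (in graph but not checkpoint, in checkpoint but not graph), sorted."""
    ckpt = set(checkpoint_names)
    graph = set(graph_names)
    return sorted(graph - ckpt), sorted(ckpt - graph)


# Hypothetical names: the checkpoint was written without a scope,
# but the graph now builds every variable under "model/".
checkpoint_names = ["rnn/weights", "rnn/biases"]
graph_names = ["model/rnn/weights", "model/rnn/biases"]

missing, unused = diff_names(checkpoint_names, graph_names)
print("in graph, not in checkpoint:", missing)
print("in checkpoint, not in graph:", unused)
```

If both lists come back non-empty and differ only by a leading scope, the saver's var_list keys need that scope removed (or the models need to live in separate graphs, as below).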
Based on the advice in that post, I built each model in its own graph. That in turn meant running each graph in its own session, which required handling the session management differently.
First, I created two graphs:
import tensorflow as tf

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)
adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)
Then two sessions:
adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)
Then I initialised the variables in each session and restored each graph separately:
with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)
with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
From here, whenever I needed one of the sessions, I wrapped the tf calls for it in that session's default context, e.g. with sess.as_default():. At the end, I closed the sessions manually:
sess.close()
adv_sess.close()
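Closing the sessions manually means they stay open if anything raises before the close() calls. One alternative, sketched here under the assumption that each session only needs close() called on it, is contextlib.ExitStack combined with contextlib.closing, which closes both sessions on exit even when an exception propagates. FakeSession, use_two_sessions and make_fake are hypothetical stand-ins so the sketch runs without TensorFlow; with TensorFlow, make_session might be something like lambda name: tf.Session(graph=graphs[name]), where graphs is an assumed dict holding the two graphs. closing() is used instead of entering each session with a with statement, because entering a tf.Session that way also installs it as the default session.

```python
import contextlib


class FakeSession:
    """Stand-in for tf.Session; only close() matters for this sketch."""

    def __init__(self, name):
        self.name = name
        self.closed = False

    def close(self):
        self.closed = True


def use_two_sessions(make_session, body):
    """Open a model session and an adversary session, run body, always close both."""
    with contextlib.ExitStack() as stack:
        # closing() only guarantees close() on exit; it has no other side effects.
        sess = stack.enter_context(contextlib.closing(make_session("model")))
        adv_sess = stack.enter_context(contextlib.closing(make_session("adv")))
        return body(sess, adv_sess)


opened = []


def make_fake(name):
    """Create a FakeSession and remember it so we can inspect it afterwards."""
    s = FakeSession(name)
    opened.append(s)
    return s


try:
    # Simulate a failure while both sessions are in use.
    use_two_sessions(make_fake, lambda sess, adv_sess: 1 / 0)
except ZeroDivisionError:
    pass

print([(s.name, s.closed) for s in opened])
```

Inside body, calls that rely on a default session still need the relevant with sess.as_default(): block; the ExitStack only takes care of cleanup.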
The answer marked as correct does not explicitly show how to load two different models into one session, so here is my answer:
1. Create two different name scopes for the models you want to load.
2. Initialize two savers, each of which loads the parameters for the variables of one network.
3. Restore each saver from its corresponding checkpoint file.
import tensorflow as tf

with tf.Session() as sess:
    with tf.name_scope("net1"):
        net1 = Net1()
    with tf.name_scope("net2"):
        net2 = Net2()
    # Key each variable by its name without the scope prefix, as stored in the checkpoint.
    # Slice the prefix off: str.lstrip() strips a set of characters, not a prefix.
    net1_varlist = {v.op.name[len("net1/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}
    net1_saver = tf.train.Saver(var_list=net1_varlist)
    net2_varlist = {v.op.name[len("net2/"):]: v
                    for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net2/")}
    net2_saver = tf.train.Saver(var_list=net2_varlist)
    net1_saver.restore(sess, "net1.ckpt")
    net2_saver.restore(sess, "net2.ckpt")
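The dict comprehensions above remove the scope prefix so that the keys match the names stored in each checkpoint. One caveat worth checking: str.lstrip(chars) removes any leading characters that appear in chars, not a literal prefix, so "net1/...".lstrip("net1/") only gives the right key when the remainder happens not to start with n, e, t, 1 or /. The sketch below is plain Python with made-up variable names; strip_scope and checkpoint_key_map are hypothetical helpers (on Python 3.9+, str.removeprefix does the same job as strip_scope).

```python
def strip_scope(name, scope):
    """Remove a leading "<scope>/" from name; leave other names unchanged."""
    prefix = scope.rstrip("/") + "/"
    return name[len(prefix):] if name.startswith(prefix) else name


def checkpoint_key_map(var_names, scope):
    """Map checkpoint keys (scope removed) to the in-graph names under scope."""
    prefix = scope.rstrip("/") + "/"
    return {strip_scope(n, scope): n for n in var_names if n.startswith(prefix)}


names = ["net1/embedding/weights", "net1/dense/kernel", "net2/dense/kernel"]

# lstrip eats the leading "e" of "embedding" as well as the scope.
print("net1/embedding/weights".lstrip("net1/"))  # -> mbedding/weights
print(checkpoint_key_map(names, "net1"))
```

In the TensorFlow code, the same helper could build the saver's var_list as {strip_scope(v.op.name, "net1"): v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="net1/")}.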