TensorFlow: how to restore two variables from the same checkpoint variable

I have saved a model, and now I am trying to restore one of its saved variables into two separate branches (a "red" one and a "blue" one).

I wrote the following code, but it raises ValueError: The same saveable will be restored with two names. How can I restore two variables from the same checkpoint variable?

restore_variables = {}
# Map checkpoint names to the graph variables of the red and blue branches
for varr in tf.global_variables():
    if varr.op.name in checkpoint_variables:
        restore_variables[varr.op.name.split("_red")[0]] = varr
        restore_variables[varr.op.name.split("_blue")[0]] = varr
init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
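To see which keys that loop actually produces, here is a small pure-Python sketch of the two split calls. It assumes, hypothetically, that the graph variables are named b, b_red and b_blue; no TensorFlow is involved, only the string handling is modeled.

```python
def keys_for(name):
    """Return the two dict keys the loop above creates for one variable name."""
    return name.split("_red")[0], name.split("_blue")[0]

# Hypothetical graph variable names; the real ones depend on the model.
for name in ["b", "b_red", "b_blue"]:
    red_key, blue_key = keys_for(name)
    # Two different keys here means the same variable object is stored twice.
    print(name, "->", red_key, blue_key, "| same variable under two keys:", red_key != blue_key)
```

With a suffixed name such as b_red, one split strips the suffix and the other leaves the name untouched, so the same variable ends up under two different keys, which matches what the error message complains about. With an unsuffixed name both keys are equal and the second assignment simply overwrites the first.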
asked Feb 12 '20 at 03:02 by Aleeee


2 Answers

Tested on TF 1.15

Basically, the error says that the restore_variables dict contains two references to the same variable. The fix is simple: for one of the two references, create a copy of the variable with tf.Variable(varr), as follows.

I think it's safe to assume that you want two separate variables here, not two references to the same variable. (If you wanted to use the same variable in several places, you could simply reuse that single variable.)

with tf.Session() as sess:
    # `saver` is the Saver that wrote the checkpoint (defined in the full code below)
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables = ['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name.split("_red")[0]] = varr
            # Use a separate copy for the second reference, so that no single
            # variable object appears twice in the dict
            restore_variables[varr.op.name.split("_blue")[0]] = tf.Variable(varr)
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
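Why the copy helps can be illustrated with a simplified pure-Python model of the check. This is only a sketch under assumptions: FakeVariable and check_var_list are hypothetical names, the error text is modeled on the one in the question, and the real tf.train.Saver logic is more involved. The point it shows is that the check is about object identity, not about the names used as keys.

```python
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable; only object identity matters here."""
    def __init__(self, name):
        self.name = name

def check_var_list(var_list):
    """Simplified model: each variable object may appear under one name only."""
    seen_ids = set()
    for ckpt_name, var in var_list.items():
        if id(var) in seen_ids:
            raise ValueError(
                "The same saveable will be restored with two names: %s" % ckpt_name)
        seen_ids.add(id(var))
    return var_list

b = FakeVariable("b")

# Two names for the same object: rejected, like the original code.
try:
    check_var_list({"b_red": b, "b_blue": b})
except ValueError as err:
    print("rejected:", err)

# A separate copy for one of the names (the role tf.Variable(varr) plays): accepted.
b_copy = FakeVariable("b_copy")
print(sorted(check_var_list({"b_red": b, "b_blue": b_copy})))
```

The model only captures the identity check; it says nothing about how the copied variable gets its value.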

Below is the full code to reproduce the issue with a toy example. There are two variables, a and b, and from b we create the b_red and b_blue entries.

# Saving the variables

import tensorflow as tf
import numpy as np
a = tf.placeholder(shape=[None, 3], dtype=tf.float64)
w1 = tf.Variable(np.random.normal(size=[3,2]), name='a')
out = tf.matmul(a, w1)
w2 = tf.Variable(np.random.normal(size=[2,3]), name='b')
out = tf.matmul(out, w2)

saver = tf.train.Saver([w1, w2])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saved_path = saver.save(sess, './vars/vars.ckpt', global_step=0)
# Restoring the variables

with tf.Session() as sess:
    saver.restore(sess, './vars/vars.ckpt-0')
    restore_variables = {}
    checkpoint_variables=['b']
    for varr in tf.global_variables():
        if varr.op.name in checkpoint_variables:
            restore_variables[varr.op.name + "_red"] = varr
            # This line triggers the error; to fix it, use tf.Variable(varr) instead of varr
            restore_variables[varr.op.name + "_blue"] = varr
    print(restore_variables)
    init_saver = tf.train.Saver(restore_variables, max_to_keep=0)
answered Nov 15 '22 at 10:11 by thushv89


I may be misunderstanding the problem, but can't you just create two Saver objects, one per branch? Something like this:

import tensorflow as tf

# Make checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
    a = tf.Variable([1., 2.], name='a')
    sess.run(a.initializer)
    b = tf.Variable([3., 4., 5.], name='b')
    sess.run(b.initializer)
    saver = tf.train.Saver([a, b])
    saver.save(sess, 'tmp/vars.ckpt')

# Restore checkpoint
with tf.Graph().as_default(), tf.Session() as sess:
    # Red
    a_red = tf.Variable([0., 0.], name='a_red')
    b_red = tf.Variable([0., 0., 0.], name='b_red')
    saver_red = tf.train.Saver({'a': a_red, 'b': b_red})
    saver_red.restore(sess, 'tmp/vars.ckpt')
    print(a_red.eval())
    # [1. 2.]
    print(b_red.eval())
    # [3. 4. 5.]

    # Blue
    a_blue = tf.Variable([0., 0.], name='a_blue')
    b_blue = tf.Variable([0., 0., 0.], name='b_blue')
    saver_blue = tf.train.Saver({'a': a_blue, 'b': b_blue})
    saver_blue.restore(sess, 'tmp/vars.ckpt')
    print(a_blue.eval())
    # [1. 2.]
    print(b_blue.eval())
    # [3. 4. 5.]
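The key idea of the two-Saver approach is that each mapping uses the same checkpoint names as keys but points them at a different set of variables, so no single mapping contains the same variable twice. Below is a pure-Python model of that idea, a sketch only: FakeVariable and restore are hypothetical stand-ins, and a plain dict plays the role of the checkpoint saved above.

```python
class FakeVariable:
    """Hypothetical stand-in for a tf.Variable holding a plain Python list."""
    def __init__(self, name, value):
        self.name = name
        self.value = value

def restore(checkpoint, var_list):
    """Copy each saved value into the variable mapped to its checkpoint name."""
    for ckpt_name, var in var_list.items():
        var.value = list(checkpoint[ckpt_name])  # copy, so branches stay independent

# A plain dict stands in for the checkpoint: name -> saved value.
checkpoint = {"a": [1.0, 2.0], "b": [3.0, 4.0, 5.0]}

branches = {}
for suffix in ("red", "blue"):
    # One mapping per branch: same checkpoint keys, different target variables.
    var_list = {name: FakeVariable(name + "_" + suffix, [0.0] * len(value))
                for name, value in checkpoint.items()}
    restore(checkpoint, var_list)
    branches[suffix] = var_list

print(branches["red"]["b"].name, branches["red"]["b"].value)    # b_red [3.0, 4.0, 5.0]
print(branches["blue"]["b"].name, branches["blue"]["b"].value)  # b_blue [3.0, 4.0, 5.0]
```

Because each branch gets its own variables, changing one branch afterwards leaves the other untouched.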
answered Nov 15 '22 at 10:11 by jdehesa