
How to restore variables using CheckpointReader in TensorFlow

I'm trying to restore some variables from a checkpoint file, but only if a variable with the same name exists in the current model.
I found a possible approach in the TensorFlow GitHub repository.

What I want to do is check the variable names in the checkpoint file using has_tensor("variable.name"), as below:

...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    print(v.name)
    if reader.has_tensor(v.name):
        print('has tensor')
...

However, v.name returns the variable name followed by a colon and a number. For example, for variables named W_o and b_o, v.name returns W_o:0 and b_o:0.

But reader.has_tensor() expects the name without the colon and number, i.e. W_o and b_o.

My question is: how can I remove the colon and number at the end of the variable name so that I can look the variables up in the checkpoint?
Is there a better way to restore such variables?

asked Apr 15 '26 12:04 by user270700


1 Answer

You could use str.split() to drop the ":0" suffix and get the tensor name:

...
reader = tf.train.NewCheckpointReader(ckpt_path)
for v in tf.trainable_variables():
    # 'W_o:0' -> 'W_o', the name stored in the checkpoint
    tensor_name = v.name.split(':')[0]
    print(tensor_name)
    if reader.has_tensor(tensor_name):
        print('has tensor')
...
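A minimal sketch of the same name stripping in plain Python, kept free of TensorFlow so it runs anywhere. It assumes names follow the "<op_name>:<output_index>" form shown in the question (e.g. W_o:0); rpartition(':') only removes the last colon-separated suffix, so scoped names such as layer1/W_o:0 keep their scope. The helper name checkpoint_key is hypothetical.

```python
def checkpoint_key(variable_name: str) -> str:
    """Strip a trailing ':<output_index>' from a TensorFlow tensor name.

    Hypothetical helper: 'W_o:0' -> 'W_o'. Names without such a suffix
    are returned unchanged.
    """
    base, sep, index = variable_name.rpartition(":")
    # Only strip when there really is a ':<digits>' suffix at the end.
    if sep and index.isdigit():
        return base
    return variable_name


if __name__ == "__main__":
    for name in ["W_o:0", "b_o:0", "layer1/W_o:0", "already_clean"]:
        print(name, "->", checkpoint_key(name))
```

Compared with split(':')[0], this version leaves names that have no ":<digits>" suffix untouched. Inside a TensorFlow 1.x graph, v.op.name should also give the name without the suffix directly.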

Next, let me use an example to show how I would restore every possible variable from a .ckpt file. First, let's save v2 and v3 in tmp.ckpt:

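import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session, tf.train.Saver)

v1 = tf.Variable(tf.ones([1]), name='v1')
v2 = tf.Variable(2 * tf.ones([1]), name='v2')
v3 = tf.Variable(3 * tf.ones([1]), name='v3')

# Save only v2 and v3; the dict keys become the tensor names in the checkpoint
saver = tf.train.Saver({'v2': v2, 'v3': v3})

with tf.Session() as sess:
    # tf.initialize_all_variables() is deprecated in favor of this initializer
    sess.run(tf.global_variables_initializer())
    # './' prefix: some TF 1.x versions refuse a bare filename with no directory part
    saver.save(sess, './tmp.ckpt')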

Here is how I would restore, in a new graph, every variable that also appears in tmp.ckpt:

with tf.Graph().as_default():
    # Fresh graph: none of the variables from the saving script exist here
    assert len(tf.trainable_variables()) == 0
    v1 = tf.Variable(tf.zeros([1]), name='v1')
    v2 = tf.Variable(tf.zeros([1]), name='v2')

    reader = tf.train.NewCheckpointReader('./tmp.ckpt')
    restore_dict = dict()
    for v in tf.trainable_variables():
        # 'v2:0' -> 'v2', the key used in the checkpoint
        tensor_name = v.name.split(':')[0]
        if reader.has_tensor(tensor_name):
            print('has tensor ', tensor_name)
            restore_dict[tensor_name] = v

    # Saver that only knows about the variables found in the checkpoint
    saver = tf.train.Saver(restore_dict)
    with tf.Session() as sess:
        # Initialize everything first so variables missing from the checkpoint (v1) still get a value
        sess.run(tf.global_variables_initializer())
        saver.restore(sess, './tmp.ckpt')
        print(sess.run([v1, v2]))  # v1 keeps its initial 0.0 (not in tmp.ckpt); v2 is restored to 2.0

Also, you may want to check that shapes and dtypes match, not just names: restoring a checkpoint tensor into a variable of a different shape fails.
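A sketch of that name, shape and dtype check in plain Python. It assumes you have already collected two dicts mapping name to (shape, dtype): one for the checkpoint (for instance from the reader's get_variable_to_shape_map() and get_variable_to_dtype_map()) and one for the current model's variables. The function select_restorable and the (shape, dtype-string) layout are hypothetical; the restorable names would then become the keys of restore_dict as in the example above.

```python
from typing import Dict, List, Tuple

# Hypothetical layout: name -> (shape, dtype string), e.g. ([1], "float32").
VarInfo = Dict[str, Tuple[List[int], str]]


def select_restorable(ckpt: VarInfo, model: VarInfo) -> Tuple[List[str], List[str]]:
    """Return (restorable, skipped) model variable names.

    A variable is restorable only if the checkpoint holds a tensor with the
    same name, shape and dtype. Names found in the checkpoint but with a
    different shape or dtype are reported as skipped.
    """
    restorable, skipped = [], []
    for name, (shape, dtype) in model.items():
        if name not in ckpt:
            continue  # not in the checkpoint: keep the initializer's value
        ckpt_shape, ckpt_dtype = ckpt[name]
        if list(ckpt_shape) == list(shape) and ckpt_dtype == dtype:
            restorable.append(name)
        else:
            skipped.append(name)
    return restorable, skipped


if __name__ == "__main__":
    # Illustrative inputs only: v3 has a different shape in the current model.
    checkpoint = {"v2": ([1], "float32"), "v3": ([1], "float32")}
    current = {"v1": ([1], "float32"), "v2": ([1], "float32"), "v3": ([2], "float32")}
    print(select_restorable(checkpoint, current))
```

When filling the model-side dict in graph mode, variables may report a reference dtype such as float32_ref, so comparing v.dtype.base_dtype against the checkpoint dtype is likely what you want.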

answered Apr 18 '26 09:04 by rvinas