Save specific weights in TensorFlow

Tags:

tensorflow

In my neural network, I create some tf.Variable objects as follows:

```python
weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}
```

How would I save the variables in weights and biases after a specific number of iterations without saving other variables?

lhao0301 asked Sep 12 '16 12:09


1 Answer

The standard way to save variables in TensorFlow is to use a tf.train.Saver object. By default it saves all of the variables in your program (i.e., the results of tf.global_variables(), called tf.all_variables() in older releases), but you can save a subset of them by passing the optional var_list argument to the tf.train.Saver constructor:

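```python
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TF 2.x)

weights = {
    'wc1_0': tf.Variable(tf.random_normal([5, 5, 3, 64])),
    'wc1_1': tf.Variable(tf.random_normal([5, 5, 3, 64]))
}
biases = {
    'bc1_0': tf.Variable(tf.constant(0.0, shape=[64])),
    'bc1_1': tf.Variable(tf.constant(0.0, shape=[64]))
}

# Define savers for explicit subsets of the variables.
weights_saver = tf.train.Saver(var_list=weights)
biases_saver = tf.train.Saver(var_list=biases)

# ...
# You need a TensorFlow Session to save variables, and the variables
# must be initialized (or restored) before they can be saved.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# ...

# ...then call the following methods as appropriate. save() requires a
# path prefix for the checkpoint files; these paths are only examples.
weights_saver.save(sess, './weights.ckpt')  # Save the current value of the weights.
biases_saver.save(sess, './biases.ckpt')    # Save the current value of the biases.
```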
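To save only the weights "after a specific number of iterations", you can call the subset saver inside the training loop and pass global_step so that each checkpoint gets its own file name. The following is a minimal sketch, not a definitive recipe: it uses the tf.compat.v1 aliases of the same TF1 API so that it also runs on TensorFlow 2.x, and SAVE_EVERY, NUM_STEPS, the extra variable `other`, and the temporary checkpoint directory are hypothetical stand-ins for your own training setup (the training step itself is omitted).

```python
import os
import tempfile

import tensorflow as tf

# TF1-style graph API via tf.compat.v1, so the sketch also runs on TF 2.x.
tf1 = tf.compat.v1

SAVE_EVERY = 100  # hypothetical: save the weights every 100 iterations
NUM_STEPS = 300   # hypothetical: total number of training iterations
ckpt_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    weights = {
        'wc1_0': tf1.Variable(tf1.random_normal([5, 5, 3, 64])),
        'wc1_1': tf1.Variable(tf1.random_normal([5, 5, 3, 64])),
    }
    # Hypothetical extra variable that must NOT go into the weights checkpoint.
    other = tf1.Variable(0.0, name='other')

    # max_to_keep=None keeps every periodic checkpoint on disk.
    weights_saver = tf1.train.Saver(var_list=weights, max_to_keep=None)

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        for step in range(1, NUM_STEPS + 1):
            # ... run one training iteration here (omitted) ...
            if step % SAVE_EVERY == 0:
                # global_step appends "-<step>" to the prefix, e.g. weights-100.
                weights_saver.save(sess, os.path.join(ckpt_dir, 'weights'),
                                   global_step=step)

# Inspect the newest checkpoint: it should hold only the entries from `weights`.
saved = tf.train.list_variables(tf.train.latest_checkpoint(ckpt_dir))
print(sorted(name for name, _ in saved))
```

Because each save writes a separate checkpoint (weights-100, weights-200, weights-300), you can later restore the weights as they were at any of those iterations.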

Note that if you pass a dictionary to the tf.train.Saver constructor (such as the weights and/or biases dictionaries from your question), TensorFlow will use the dictionary key (e.g. 'wc1_0') as the name for the corresponding variable in any checkpoint files it creates or consumes.

By default, or if you pass a list of tf.Variable objects to the constructor, TensorFlow will use each variable's tf.Variable.name property instead (without the ":0" output suffix).
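One way to see the difference is to save the same variable with a list-based and a dictionary-based Saver and then list the keys stored in each checkpoint. This is a small sketch using the tf.compat.v1 aliases; the variable name 'layer1_weights' and the temporary directory are made up for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tmp = tempfile.mkdtemp()

with tf.Graph().as_default():
    # Hypothetical variable whose Variable.name is 'layer1_weights:0'.
    w = tf1.Variable(tf1.zeros([5, 5, 3, 64]), name='layer1_weights')

    list_saver = tf1.train.Saver(var_list=[w])           # keyed by the variable's name
    dict_saver = tf1.train.Saver(var_list={'wc1_0': w})  # keyed by the dict key

    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        list_path = list_saver.save(sess, os.path.join(tmp, 'as_list'))
        dict_path = dict_saver.save(sess, os.path.join(tmp, 'as_dict'))

# list_variables returns (name, shape) pairs for every entry in a checkpoint.
list_names = [name for name, _ in tf.train.list_variables(list_path)]
dict_names = [name for name, _ in tf.train.list_variables(dict_path)]
print(list_names)
print(dict_names)
```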

Passing a dictionary therefore lets you share checkpoints between models that give the corresponding variables different Variable.name values. This detail only matters if you want to use the created checkpoints with another model.
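The checkpoint-sharing case can be sketched with two independent graphs: model A writes its variable under the dictionary key 'wc1_0', and model B, whose variable has a different Variable.name, restores that value through the same key. The names model_a_conv and model_b_conv and the temporary path are hypothetical; the code again uses the tf.compat.v1 aliases.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
ckpt_prefix = os.path.join(tempfile.mkdtemp(), 'shared')

# Model A: the variable is named 'model_a_conv', saved under the key 'wc1_0'.
with tf.Graph().as_default():
    a_w = tf1.Variable(tf1.random_normal([5, 5, 3, 64]), name='model_a_conv')
    saver_a = tf1.train.Saver(var_list={'wc1_0': a_w})
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saved_value = sess.run(a_w)
        saver_a.save(sess, ckpt_prefix)

# Model B: a differently named variable, mapped to the same checkpoint key.
with tf.Graph().as_default():
    b_w = tf1.Variable(tf1.zeros([5, 5, 3, 64]), name='model_b_conv')
    saver_b = tf1.train.Saver(var_list={'wc1_0': b_w})
    with tf1.Session() as sess:
        # restore() assigns the saved value, so b_w needs no initializer here.
        saver_b.restore(sess, ckpt_prefix)
        restored_value = sess.run(b_w)

print(np.array_equal(saved_value, restored_value))  # True
```

With a list-based Saver, the same restore would look for a key named 'model_b_conv', which model A's checkpoint does not contain.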

mrry answered Oct 07 '22 16:10