What does it mean that a tf.Variable is trainable in TensorFlow?

This question came up while I was reading the documentation for global_step, whose example explicitly declares global_step as non-trainable.

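import tensorflow as tf  # TensorFlow 1.x API

global_step_tensor = tf.Variable(10, trainable=False, name='global_step')

sess = tf.Session()
sess.run(global_step_tensor.initializer)  # a variable must be initialized before it is read

print('global_step: %s' % tf.train.global_step(sess, global_step_tensor))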

From my understanding, trainable means that the value could be changed during sess.run(). I tried declaring it as both trainable and non-trainable and got the same result, so I don't understand why it needs to be declared non-trainable.

I read the documentation of trainable but didn't quite get it.

So my question is:

  1. Can the value of a non-trainable variable be changed during sess.run(), and vice versa?
  2. What is the point of making a variable non-trainable?
Jekyll SONG asked Jan 26 '18 11:01


1 Answer

From my understanding, trainable means that the value could be changed during sess.run()

That is not the definition of a trainable variable. Any variable can be modified during a sess.run() call (that's why they are variables and not constants).
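
As a minimal sketch of that point, the snippet below changes a non-trainable variable with an explicit assign op. It assumes TensorFlow 1.14+ or 2.x and uses the tensorflow.compat.v1 module so the graph/session API from the question is available; the names frozen and bump are made up for illustration.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: run in TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    # "frozen" is a hypothetical name; trainable=False keeps it away from optimizers.
    frozen = tf.Variable(1.0, trainable=False, name="frozen")
    # An explicit op that adds 2.0 to the variable each time it is run.
    bump = tf.assign_add(frozen, 2.0)
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    before = sess.run(frozen)
    sess.run(bump)  # the value changes even though the variable is non-trainable
    after = sess.run(frozen)

print(before, after)
```

So trainable=False does not make a variable read-only; it only affects which variables an optimizer picks up by default.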

The distinction between trainable and non-trainable variables tells optimizers which variables they may act upon. When you define a tf.Variable(), setting trainable=True (the default) automatically adds the variable to the GraphKeys.TRAINABLE_VARIABLES collection. During training, an optimizer that is not given an explicit var_list reads that collection (the same list that tf.trainable_variables() returns) and computes gradients for, and applies updates to, every variable in it.
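
To make the collection mechanism concrete, here is a small sketch, again assuming TensorFlow 1.14+ or 2.x with tensorflow.compat.v1. The variables w and frozen and the toy loss are hypothetical; the point is only that both variables feed the loss, yet the optimizer updates just the trainable one.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: run in TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(1.0, name="w")  # trainable=True is the default
    frozen = tf.Variable(1.0, trainable=False, name="frozen")

    # Toy loss that depends on both variables.
    loss = tf.square(w + frozen)

    # No var_list is passed, so the optimizer falls back to
    # the GraphKeys.TRAINABLE_VARIABLES collection.
    train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss)

    trainable_names = [v.name for v in tf.trainable_variables()]
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    sess.run(train_op)  # one gradient descent step
    w_val, frozen_val = sess.run([w, frozen])

print(trainable_names)    # only w is listed
print(w_val, frozen_val)  # w has moved, frozen has not
```

Passing var_list explicitly to minimize() would override this default, which is another way to restrict what an optimizer touches.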

The typical example of a non-trainable variable is global_step, because its value does change over time (+1 at each training iteration, typically), but you don't want to apply an optimization algorithm to it.
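
The global_step case can be sketched the same way. This assumes TensorFlow 1.14+ or 2.x with tensorflow.compat.v1, and the variable w and its loss are placeholders for a real model. Passing global_step to minimize() asks the optimizer to increment it by one after each update, without ever treating it as a parameter to optimize.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # assumption: run in TF1-style graph mode

graph = tf.Graph()
with graph.as_default():
    w = tf.Variable(5.0, name="w")  # hypothetical model parameter
    global_step = tf.Variable(0, trainable=False, name="global_step")

    loss = tf.square(w)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
    # The optimizer increments global_step as a side effect of each update.
    train_op = optimizer.minimize(loss, global_step=global_step)

    trainable_names = [v.name for v in tf.trainable_variables()]
    init = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(3):
        sess.run(train_op)
    step = tf.train.global_step(sess, global_step)

print(step)             # the counter advanced once per training step
print(trainable_names)  # global_step is not in the trainable collection
```

If global_step were left trainable, it would sit in the same collection as the model weights, and any code that iterates over tf.trainable_variables() would treat a plain counter as a model parameter.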

GPhilo answered Oct 10 '22 14:10