
How to create a non-trainable variable in Tensorflow?

Tags:

tensorflow

Is there a parameter that marks a tf.Variable as non-trainable, so that the variable is not included in tf.trainable_variables()?

cerebrou asked Dec 10 '22 14:12



2 Answers

You can mark a variable as "non-trainable" when you define it:

v = tf.Variable(tf.zeros([1]), trainable=False)

From the linked documentation (circa TensorFlow v0.11):

trainable: If True, the default, also adds the variable to the graph collection GraphKeys.TRAINABLE_VARIABLES. This collection is used as the default list of variables to use by the Optimizer classes.

tf.get_variable accepts the same trainable argument, so a variable created through it can be kept out of that collection in the same way.
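To see the effect of the flag, here is a minimal sketch that builds two variables in a fresh graph and lists the contents of the trainable collection. It assumes a TensorFlow install that exposes the TF1-style API as tensorflow.compat.v1; the variable names "w" and "step" are made up for illustration.

```python
import tensorflow.compat.v1 as tf  # assumption: TF1-style graph API is available

# Build everything in a private graph so the collections start out empty.
graph = tf.Graph()
with graph.as_default():
    # Trainable by default: added to GraphKeys.TRAINABLE_VARIABLES.
    w = tf.Variable(tf.zeros([1]), name="w")
    # Non-trainable: still a global variable, but optimizers skip it by default.
    step = tf.Variable(tf.zeros([1]), name="step", trainable=False)

    trainable_names = [v.name for v in tf.trainable_variables()]
    global_names = [v.name for v in tf.global_variables()]

print(trainable_names)  # only the trainable variable should be listed
print(global_names)     # both variables should be listed
```

The non-trainable variable still shows up in tf.global_variables(), so it is initialized and saved like any other variable; it is only left out of the default list an optimizer updates.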

Eric Platon answered Jan 21 '23 12:01



You can create non-trainable variables in two different ways:

  • tf.Variable(a, trainable=False)
  • tf.get_variable("a", initializer=a, trainable=False)

There is no easy way to switch a variable between trainable and non-trainable after it has been created. There is also no easy way to check whether a variable is trainable: you need to check whether its name is in the list returned by tf.trainable_variables().
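The name check described above could look roughly like this. It is only a sketch: is_trainable is a hypothetical helper, not part of TensorFlow, and the code assumes the TF1-style API is importable as tensorflow.compat.v1.

```python
import tensorflow.compat.v1 as tf  # assumption: TF1-style graph API is available


def is_trainable(var, graph):
    """Hypothetical helper: True if `var` is in the graph's trainable collection."""
    trainable = graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    return var.name in {v.name for v in trainable}


graph = tf.Graph()
with graph.as_default():
    # Trainable by default.
    a = tf.get_variable("a", shape=[1], initializer=tf.zeros_initializer())
    # Explicitly excluded from the trainable collection.
    b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer(),
                        trainable=False)

a_is_trainable = is_trainable(a, graph)
b_is_trainable = is_trainable(b, graph)
print(a_is_trainable, b_is_trainable)
```

Reading the collection from the graph object is equivalent to calling tf.trainable_variables() while that graph is the default one.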

Salvador Dali answered Jan 21 '23 12:01
