Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

What's the differences between tf.GraphKeys.TRAINABLE_VARIABLES and tf.GraphKeys.UPDATE_OPS in tensorflow?

Tags:

tensorflow

Here is doc of tf.GraphKeys in tensorflow, such as TRAINABLE_VARIABLES: the subset of Variable objects that will be trained by an optimizer.

And i know tf.get_collection(), which can find some tensor that you want.

When use tensorflow.contrib.layers.batch_norm(), the parameter updates_collections default value is GraphKeys.UPDATE_OPS.

How can we understand those collections, and difference in them.

Besides, we can find more in ops.py.

like image 674
dxf Avatar asked Jan 15 '18 09:01

dxf


1 Answers

These are two different things.

TRAINABLE_VARIABLES

TRAINABLE_VARIABLES is the collection of variables or training parameters which should be modified when minimizing the loss. For example, these can be the weights determining the function performed by each node in the network.

How do variables get added to this collection? This happens automatically when you define a new variable with tf.get_variable, unless you specify

tf.get_variable(..., trainable=False)

When would you want a variable to be untrainable? This happens from time to time. For example, occasionally you will want to use a two-step approach in which you first train the entire network on a large, generic dataset, then fine-tune the network on a smaller dataset which is specifically related to your problem. In such cases, you might want to fine-tune only part of the network, e.g., the last layer. Specifying some variables as untrainable is one of the ways to do this.

UPDATE_OPS

UPDATE_OPS is a collection of ops (operations performed when the graph runs, like multiplication, ReLU, etc.), not variables. Specifically, this collection maintains a list of ops which need to run before each training step.

How do ops get added to this collection? By definition, update_ops occur outside the regular flow of training by loss minimization, so generally you will be adding ops to this collection only under special circumstances. For example, when performing batch normalization, you want to recompute the batch mean and variance before each training step, and this is how it's done. The mechanics of batch normalization using tf.contrib.layers.batch_norm are described in more detail in this article.

like image 79
Zvika Avatar answered Oct 12 '22 12:10

Zvika