How do you use TensorFlow GraphKeys to get all weights?

TensorFlow defines a set of standard collections, listed here: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections

I am currently using tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) to get all the variables (*only the named ones; unnamed variables are not displayed, even though they exist).

Similarly, I expected tf.get_collection(tf.GraphKeys.WEIGHTS) to return a list of weights, but instead it returns an empty list. The same is true for GraphKeys.BIASES and GraphKeys.ACTIVATIONS.

What's going on here?

As I see it, there are two possibilities. Either these collections are never populated automatically and are just recommended collection names, or my network is badly broken, which doesn't seem to be the case.

Does anyone have experience with this?

SStensors asked Jul 22 '17 16:07


2 Answers

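By default, every variable is added to the tf.GraphKeys.GLOBAL_VARIABLES collection. A convenient approach is to add each weight to the tf.GraphKeys.WEIGHTS collection explicitly when you create it. Note that the collections argument replaces the default list, so keep GLOBAL_VARIABLES in it as well; otherwise tf.global_variables_initializer() will not initialize the variable:

In [1]: import tensorflow as tf
In [2]: w = tf.Variable([1,2,3], collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.WEIGHTS], dtype=tf.float32)
In [3]: w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.WEIGHTS], dtype=tf.float32)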

Then you can fetch them with:

tf.get_collection_ref(tf.GraphKeys.WEIGHTS)

This returns the weights:

[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,
 <tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
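A self-contained version of the above, as a minimal sketch assuming TF1-style graph mode (written against tf.compat.v1 so it should also run on TF 2.x). It shows two ways to fill the collection: passing collections= when the variable is created, or calling tf.compat.v1.add_to_collection afterwards. The names w1, w2 and b are illustrative only.

```python
import tensorflow as tf

# Sketch assuming TF1-style graph mode, accessed through tf.compat.v1.
graph = tf.Graph()
with graph.as_default():
    GK = tf.compat.v1.GraphKeys

    # Option 1: pass the collections at creation time. Keep GLOBAL_VARIABLES in the
    # list, otherwise global_variables_initializer() will not initialize w1.
    w1 = tf.compat.v1.Variable([1.0, 2.0, 3.0], name="w1",
                               collections=[GK.GLOBAL_VARIABLES, GK.WEIGHTS])

    # Option 2: create the variable normally, then register it afterwards.
    w2 = tf.compat.v1.Variable([11.0, 22.0, 32.0], name="w2")
    tf.compat.v1.add_to_collection(GK.WEIGHTS, w2)

    # A bias that is deliberately NOT registered as a weight.
    b = tf.compat.v1.Variable([0.0], name="b")

    weights = tf.compat.v1.get_collection(GK.WEIGHTS)
    weight_names = [v.name for v in weights]
    global_names = [v.name for v in tf.compat.v1.global_variables()]
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)                     # succeeds because w1 is still in GLOBAL_VARIABLES
    weight_values = sess.run(weights)  # list of NumPy arrays, one per weight

print(weight_names)   # ['w1:0', 'w2:0']
print(global_names)   # ['w1:0', 'w2:0', 'b:0']
```

The second option is handy when the variables are created by code you do not control, since all you need is a reference to the variable object.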
Lerner Zhang answered Oct 19 '22 19:10


From the docs:

The following standard keys are defined, but their collections are not automatically populated as many of the others are:

  • WEIGHTS
  • BIASES
  • ACTIVATIONS
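The quoted behaviour can be checked directly. Below is a minimal sketch, assuming graph-mode code through tf.compat.v1, showing that WEIGHTS and BIASES stay empty even though a layer's variables exist. As a workaround that is not taken from the docs, it then selects the weights by name; the layer1 scope and the weights/biases variable names are a hypothetical convention, so the filter would need to match however your own variables are named.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    GK = tf.compat.v1.GraphKeys

    # A hypothetical layer; "weights"/"biases" are our own naming convention.
    with tf.compat.v1.variable_scope("layer1"):
        w = tf.compat.v1.get_variable("weights", shape=[4, 3])
        b = tf.compat.v1.get_variable(
            "biases", shape=[3], initializer=tf.compat.v1.zeros_initializer())

    # Neither collection is filled in automatically by get_variable.
    auto_weights = tf.compat.v1.get_collection(GK.WEIGHTS)
    auto_biases = tf.compat.v1.get_collection(GK.BIASES)

    # Fallback (assumption): pick trainable variables by the naming convention above.
    weights_by_name = [v for v in tf.compat.v1.trainable_variables()
                       if v.name.endswith("/weights:0")]
    weight_names = [v.name for v in weights_by_name]

print(len(auto_weights), len(auto_biases))  # 0 0
print(weight_names)                          # ['layer1/weights:0']
```

Filtering by name only works as long as the naming convention is applied consistently; registering the variables explicitly, as in the other answer, avoids that dependency.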
P-Gn answered Oct 19 '22 21:10