TensorFlow defines a set of standard collection keys, as listed here: https://www.tensorflow.org/versions/r0.12/api_docs/python/framework/graph_collections
I am currently using tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) to get all the variables [*which are named; if they are not, then they are not displayed, even if they exist].
Similarly, I expected tf.get_collection(tf.GraphKeys.WEIGHTS) to return a list of weights, but instead it returns an empty list. The same holds for GraphKeys.BIASES and GraphKeys.ACTIVATIONS.
What's going on here?
In my opinion, there seem to be two possibilities. First, these collections are never populated automatically and are just recommended collection names. Second, my network is badly broken, but that doesn't seem to be the case.
Anyone have experience with this?
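By default, every variable is added to the tf.GraphKeys.GLOBAL_VARIABLES collection, but nothing adds it to tf.GraphKeys.WEIGHTS for you. A convenient way is to put each weight into the tf.GraphKeys.WEIGHTS collection yourself when you create it. Note that the collections argument replaces the default list, so include GLOBAL_VARIABLES as well; otherwise the variable is dropped from GLOBAL_VARIABLES and is skipped by tf.global_variables_initializer():
import tensorflow as tf  # TensorFlow 1.x API
w = tf.Variable([1, 2, 3], collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.WEIGHTS], dtype=tf.float32)
w2 = tf.Variable([11, 22, 32], collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.WEIGHTS], dtype=tf.float32)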
Then you can fetch them with:
tf.get_collection_ref(tf.GraphKeys.WEIGHTS)
which returns the weights:
[<tf.Variable 'Variable:0' shape=(3,) dtype=float32_ref>,
<tf.Variable 'Variable_1:0' shape=(3,) dtype=float32_ref>]
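To make the effect of the collections argument concrete, here is a toy, pure-Python model of TF 1.x graph collections. It is a hypothetical stand-in, not TensorFlow: ToyGraph, ToyVariable and the key strings are made-up names. It only mimics the documented tf.Variable behaviour that collections defaults to [GLOBAL_VARIABLES], that an explicit list replaces that default, and that trainable variables also go into TRAINABLE_VARIABLES. In real TF 1.x code the "add afterwards" step would be tf.add_to_collection(tf.GraphKeys.WEIGHTS, w).

```python
from collections import defaultdict

# Hypothetical stand-ins for tf.GraphKeys constants (illustrative strings, not TensorFlow).
GLOBAL_VARIABLES = "variables"
TRAINABLE_VARIABLES = "trainable_variables"
WEIGHTS = "weights"


class ToyGraph:
    """Toy model of a graph's collections: a dict of named lists."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, key, value):
        self._collections[key].append(value)

    def get_collection(self, key):
        # Returns a copy, as tf.get_collection does; an unknown key gives [].
        return list(self._collections.get(key, []))


class ToyVariable:
    """Mimics tf.Variable's handling of the collections argument."""

    def __init__(self, graph, name, trainable=True, collections=None):
        self.name = name
        # An explicit list REPLACES the default [GLOBAL_VARIABLES].
        keys = [GLOBAL_VARIABLES] if collections is None else list(collections)
        if trainable and TRAINABLE_VARIABLES not in keys:
            keys.append(TRAINABLE_VARIABLES)
        for key in keys:
            graph.add_to_collection(key, self)


g = ToyGraph()
w = ToyVariable(g, "w", collections=[WEIGHTS])                      # loses GLOBAL_VARIABLES
w2 = ToyVariable(g, "w2", collections=[GLOBAL_VARIABLES, WEIGHTS])  # keeps the default
w3 = ToyVariable(g, "w3")                                           # default collections only
g.add_to_collection(WEIGHTS, w3)                                    # added to WEIGHTS afterwards

print([v.name for v in g.get_collection(WEIGHTS)])           # ['w', 'w2', 'w3']
print([v.name for v in g.get_collection(GLOBAL_VARIABLES)])  # ['w2', 'w3']
```

Both w2 and w3 end up in WEIGHTS and GLOBAL_VARIABLES; only w silently disappears from the global list.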
From the docs:
The following standard keys are defined, but their collections are not automatically populated as many of the others are:
WEIGHTS
BIASES
ACTIVATIONS
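Since these keys are not populated automatically, one option is to fill them yourself after building the network. The sketch below works on plain variable-name strings in a dict, not on a TensorFlow graph, and assumes a made-up naming convention (weight names end in "/weights:0", bias names in "/biases:0"); fill_standard_collections is a hypothetical helper. In TF 1.x the same idea would loop over tf.global_variables() and call tf.add_to_collection.

```python
from collections import defaultdict

# Hypothetical key strings standing in for tf.GraphKeys constants.
GLOBAL_VARIABLES, WEIGHTS, BIASES = "variables", "weights", "biases"


def fill_standard_collections(collections):
    """Copy names from GLOBAL_VARIABLES into WEIGHTS / BIASES by name suffix.

    Assumes a made-up convention: '/weights:0' marks weights, '/biases:0' marks
    biases. Names already present are not added twice, so repeated calls are safe.
    """
    suffix_to_key = {"/weights:0": WEIGHTS, "/biases:0": BIASES}
    for name in collections[GLOBAL_VARIABLES]:
        for suffix, key in suffix_to_key.items():
            if name.endswith(suffix) and name not in collections[key]:
                collections[key].append(name)
    return collections


cols = defaultdict(list)
# In this toy only GLOBAL_VARIABLES starts out filled, mirroring the docs quote.
cols[GLOBAL_VARIABLES] += ["fc1/weights:0", "fc1/biases:0", "fc2/weights:0", "step:0"]
print(cols[WEIGHTS])  # []
fill_standard_collections(cols)
print(cols[WEIGHTS])  # ['fc1/weights:0', 'fc2/weights:0']
print(cols[BIASES])   # ['fc1/biases:0']
```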