
Tensorflow get all variables in scope

Tags:

tensorflow

I have some variables created within a certain scope like this:

with tf.variable_scope("my_scope"):
  createSomeVariables()
  ...

I then want to get the list of all the variables in "my_scope" so I can pass it to an optimizer. What is the right way to do this?

Vlad Firoiu asked Apr 10 '16 18:04


2 Answers

I think you want tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='my_scope'). This returns every variable whose name matches the scope.

To pass them to an optimizer, though, you do not want all variables, only the trainable ones. Those are also kept in a default collection, tf.GraphKeys.TRAINABLE_VARIABLES, so use that key in the tf.get_collection call instead.
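As a rough sketch of how that might fit together end to end: the snippet below assumes the TF1-style graph API, reached through tf.compat.v1 so that it should also run under TensorFlow 2.x. The variables w, step and v, the second scope, and the toy loss are all made up for illustration and are not from the question.

```python
import tensorflow as tf

# Assumption: TF1-style graph API, accessed via tf.compat.v1.
tf1 = tf.compat.v1

graph = tf.Graph()
with graph.as_default():
    with tf1.variable_scope("my_scope"):
        # Hypothetical variables standing in for createSomeVariables().
        w = tf1.get_variable("w", shape=[2], initializer=tf1.zeros_initializer())
        step = tf1.get_variable("step", shape=[], trainable=False,
                                initializer=tf1.zeros_initializer())
    with tf1.variable_scope("other_scope"):
        v = tf1.get_variable("v", shape=[2], initializer=tf1.ones_initializer())

    # Toy loss that depends on variables from both scopes.
    loss = tf.reduce_sum(tf.square(w - 1.0)) + tf.reduce_sum(tf.square(v))

    # Every variable in my_scope, trainable or not.
    all_vars = tf1.get_collection(tf1.GraphKeys.GLOBAL_VARIABLES, scope="my_scope")
    # Only the trainable variables in my_scope; this is the list for the optimizer.
    train_vars = tf1.get_collection(tf1.GraphKeys.TRAINABLE_VARIABLES, scope="my_scope")

    optimizer = tf1.train.GradientDescentOptimizer(learning_rate=0.1)
    # var_list restricts the update to the selected variables.
    train_op = optimizer.minimize(loss, var_list=train_vars)

all_names = [var.name for var in all_vars]
train_names = [var.name for var in train_vars]
print(all_names)
print(train_names)
```

Here the non-trainable step shows up in the GLOBAL_VARIABLES list but not in the TRAINABLE_VARIABLES one, and v is left out of both because it lives in a different scope.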

user728291 answered Nov 20 '22 05:11


user728291 correctly pointed out that you need tf.get_collection(). Here is a simple example of how to use it:

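import tensorflow as tf  # TensorFlow 1.x graph-mode API

# The name must be passed by keyword: the second positional
# argument of tf.Variable is `trainable`, not `name`.
with tf.name_scope('some_scope1'):
    a = tf.Variable(1, name='a')
    b = tf.Variable(2, name='b')
    c = tf.Variable(3, name='c')

with tf.name_scope('some_scope2'):
    d = tf.Variable(4, name='d')
    e = tf.Variable(5, name='e')
    f = tf.Variable(6, name='f')

# Created outside any scope, so the filter below does not match it.
h = tf.Variable(8, name='h')

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='some_scope'):
    print(i)   # i.name if you want just a name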

Note that you can pass any of the tf.GraphKeys, and that scope is treated as a regular expression. From the documentation:

scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name attribute matches scope using re.match. Items without a name attribute are never returned if a scope is supplied. The choice of re.match means that a scope without special tokens filters by prefix.

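So if you pass 'some_scope', you get 6 variables: the ones created inside some_scope1 and some_scope2, but not h, because 'some_scope' is a prefix of both scope names.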
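To make the prefix behaviour concrete, here is a small pure-Python sketch of the same filtering rule. It does not call TensorFlow at all: filter_by_scope is a hypothetical helper that just applies re.match, and the names are written by hand in the form TensorFlow would give variables named a through f and h.

```python
import re

# Hand-written names, assumed for illustration.
names = [
    "some_scope1/a:0", "some_scope1/b:0", "some_scope1/c:0",
    "some_scope2/d:0", "some_scope2/e:0", "some_scope2/f:0",
    "h:0",
]

def filter_by_scope(names, scope):
    """Keep the names that `scope` matches at the start, as re.match does."""
    return [n for n in names if re.match(scope, n)]

# A plain string acts as a prefix, so both scopes match.
prefix_matches = filter_by_scope(names, "some_scope")
# Including the trailing slash pins the match to one scope.
exact_scope = filter_by_scope(names, "some_scope1/")
# re.match anchors at the start, so a mid-string fragment matches nothing.
no_anchor = filter_by_scope(names, "scope")
# Regular-expression tokens work too.
regex_matches = filter_by_scope(names, "some_scope[12]/[ad]")

print(len(prefix_matches))
print(exact_scope)
print(no_anchor)
print(regex_matches)
```

If you only want one scope, it is probably safest to include the trailing slash, since 'some_scope1' alone would also match a scope called, say, some_scope10.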

Salvador Dali answered Nov 20 '22 05:11