TensorFlow: getting variable by name

Tags:

tensorflow

When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value "Variable_23:0". When I try to select this variable using tf.get_variable("Variable_23"), a new variable called "Variable_23_1:0" is created instead. How do I correctly select "Variable_23" instead of creating a new one?

What I want to do is select the variable by name and reinitialize it so I can fine-tune the weights.

user3528623 asked Feb 28 '16 04:02


2 Answers

The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:

    >>> with tf.variable_scope("foo"):
    ...     bar1 = tf.get_variable("bar", (2, 3))  # create
    ...
    >>> with tf.variable_scope("foo", reuse=True):
    ...     bar2 = tf.get_variable("bar")  # reuse
    ...
    >>> with tf.variable_scope("", reuse=True):  # root variable scope
    ...     bar3 = tf.get_variable("foo/bar")  # reuse (equivalent to the above)
    ...
    >>> (bar1 is bar2) and (bar2 is bar3)
    True

If you did not create the variable using tf.get_variable(), you have a couple of options. First, you can use tf.global_variables() (as @mrry suggests):

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = [var for var in tf.global_variables() if var.op.name == "bar"][0]
    >>> bar1 is bar2
    True

Or you can use tf.get_collection() like so:

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
    >>> bar1 is bar2
    True
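
One caveat with the scope argument: as far as I know, get_collection() filters items by applying re.match to each item's name, so scope="bar" acts as a prefix pattern rather than an exact match. The sketch below reproduces that kind of filtering in plain Python on made-up variable names. Both filter_by_scope and the names list are illustrative assumptions, not TensorFlow APIs.

```python
import re

def filter_by_scope(names, scope):
    """Keep the names for which re.match(scope, name) succeeds (match anchors at the start)."""
    pattern = re.compile(scope)
    return [name for name in names if pattern.match(name)]

# Hypothetical variable names, as they might appear in a graph.
names = ["bar:0", "bar_1:0", "foo/bar:0", "barrier/w:0"]

# Prefix-style match: also picks up bar_1 and barrier/w.
print(filter_by_scope(names, "bar"))

# Pattern anchored at the end: only the exact name remains.
print(filter_by_scope(names, r"bar:0$"))
```

If several variables share a prefix, indexing the result with [0] may return a different variable than the one intended, so an anchored pattern or an exact name comparison is the safer choice.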

Edit

You can also use get_tensor_by_name():

    >>> bar1 = tf.Variable(0.0, name="bar")
    >>> graph = tf.get_default_graph()
    >>> bar2 = graph.get_tensor_by_name("bar:0")
    >>> bar1 is bar2
    False

Here bar2 is a Tensor, not the tf.Variable object itself (much like what convert_to_tensor would give you for bar1), so the identity check is False. The two are still equal in value.

Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0. If the operation has multiple outputs, they have the same name as the operation plus :0, :1, :2, and so on.
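
To make the naming rule above concrete, here is a minimal sketch that splits a tensor name into its operation name and output index. The helper split_tensor_name is hypothetical (it is not part of TensorFlow) and only assumes the "op_name:index" convention described above.

```python
def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index)."""
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not index.isdigit():
        raise ValueError("expected 'op_name:index', got %r" % tensor_name)
    return op_name, int(index)

print(split_tensor_name("bar:0"))          # operation "bar", first output
print(split_tensor_name("Variable_23:0"))  # the variable from the question
```

This is also why the earlier snippet compares var.op.name against "bar" while the tensor lookup uses "bar:0": the first is the operation name, the second is the name of that operation's first output.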

MiniQuark answered Oct 07 '22 16:10


The easiest way to get a variable by name is to search for it in the tf.global_variables() collection:

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0] 

This works well for ad hoc reuse of existing variables. A more structured approach—for when you want to share variables between multiple parts of a model—is covered in the Sharing Variables tutorial.
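
The list comprehension above raises a bare IndexError when no variable matches, and silently picks the first one when several do. A slightly more defensive version could look like the sketch below. The helper find_by_name and the FakeVar stand-in are assumptions made for illustration, so that the example runs without TensorFlow; any objects with a name attribute would work.

```python
from collections import namedtuple

def find_by_name(variables, name):
    """Return the single item whose .name equals `name`; raise if none or several match."""
    matches = [v for v in variables if v.name == name]
    if len(matches) != 1:
        raise LookupError(
            "expected exactly one variable named %r, found %d" % (name, len(matches))
        )
    return matches[0]

# Stand-in for tf.Variable objects (hypothetical, for illustration only).
FakeVar = namedtuple("FakeVar", ["name"])
fake_vars = [FakeVar("Variable_22:0"), FakeVar("Variable_23:0")]

print(find_by_name(fake_vars, "Variable_23:0").name)
```

With TensorFlow 1.x, the call would presumably be find_by_name(tf.global_variables(), "Variable_23:0").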

mrry answered Oct 07 '22 16:10