When using the TensorFlow Python API, I created a variable (without specifying its name in the constructor), and its name property had the value "Variable_23:0". When I try to select this variable using tf.get_variable("Variable_23"), a new variable called "Variable_23_1:0" is created instead. How do I correctly select "Variable_23" instead of creating a new one?
What I want to do is select the variable by name and reinitialize it so I can fine-tune its weights.
The get_variable() function creates a new variable or returns one created earlier by get_variable(). It won't return a variable created using tf.Variable(). Here's a quick example:
>>> with tf.variable_scope("foo"):
...     bar1 = tf.get_variable("bar", (2,3)) # create
...
>>> with tf.variable_scope("foo", reuse=True):
...     bar2 = tf.get_variable("bar") # reuse
...
>>> with tf.variable_scope("", reuse=True): # root variable scope
...     bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
...
>>> (bar1 is bar2) and (bar2 is bar3)
True
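The symptom in the question (ending up with "Variable_23_1:0") follows from two TF1 behaviors: get_variable() only looks in its own store of variables, and the graph makes op names unique by appending _1, _2, and so on. The toy model below is a plain-Python sketch of just those two rules. ToyGraph and its methods are hypothetical names, not TensorFlow APIs, and real TensorFlow does much more (scopes, shapes, initializers).

```python
# Simplified, hypothetical model of two TF1 behaviors; NOT TensorFlow's code.
class ToyGraph:
    def __init__(self):
        self._op_names = set()  # every op name already used in the graph
        self._var_store = {}    # only variables created via get_variable()

    def unique_name(self, name):
        # Mimic graph name uniquification: "x", then "x_1", "x_2", ...
        candidate, i = name, 0
        while candidate in self._op_names:
            i += 1
            candidate = f"{name}_{i}"
        self._op_names.add(candidate)
        return candidate

    def variable(self, name="Variable"):
        # Like tf.Variable(): gets a unique op name but is NOT registered
        # in the get_variable() store.
        return self.unique_name(name) + ":0"

    def get_variable(self, name, reuse=False):
        if reuse:
            if name not in self._var_store:
                raise ValueError(f"Variable {name} does not exist in the store")
            return self._var_store[name]
        if name in self._var_store:
            raise ValueError(f"Variable {name} already exists")
        var = self.unique_name(name) + ":0"
        self._var_store[name] = var
        return var


g = ToyGraph()
print(g.variable("Variable_23"))      # Variable_23:0
print(g.get_variable("Variable_23"))  # Variable_23_1:0
```

Here g.variable("Variable_23") plays the role of the original tf.Variable() call. The later g.get_variable("Variable_23") misses the store, so it creates a second op, and because "Variable_23" is already taken the new op is renamed "Variable_23_1".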
If you did not create the variable using tf.get_variable(), you have a couple of options. First, you can use tf.global_variables() (as @mrry suggests):
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True
Or you can use tf.get_collection() like so:
>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True
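One caveat with the scope argument: as documented for TF1, get_collection() keeps items whose name attribute matches scope using re.match, which anchors only at the start of the string. So scope="bar" would also pick up a variable named "barbell", and [0] could return the wrong one. The plain-Python sketch below imitates that filtering on hypothetical variable names; filter_by_scope is an illustrative helper, not a TensorFlow function.

```python
import re

def filter_by_scope(names, scope):
    # Keep names for which re.match(scope, name) succeeds. re.match anchors
    # only at the START of the string, so this behaves like a prefix match.
    return [n for n in names if re.match(scope, n)]

# Hypothetical variable names, including the ":0" output suffix.
names = ["bar:0", "barbell:0", "foo/bar:0"]
print(filter_by_scope(names, "bar"))   # ['bar:0', 'barbell:0']
print(filter_by_scope(names, "bar:"))  # ['bar:0']
```

Including the colon (scope="bar:") narrows the match to the op named exactly bar, because variable names carry the :0 output suffix.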
Edit
You can also use get_tensor_by_name():
>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False
The result is False because bar2 is a plain Tensor (the output of the variable's op), not the tf.Variable object itself; the two do, however, evaluate to the same value.
Recall that a tensor is the output of an operation. It has the same name as the operation, plus :0. If the operation has multiple outputs, they have the same name as the operation plus :0, :1, :2, and so on.
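Because of this convention, a tensor name can be split back into its operation name and output index. A small sketch in plain Python; split_tensor_name is a hypothetical helper, and it assumes names of the form <op_name>:<index> as described above.

```python
def split_tensor_name(tensor_name):
    # "<op_name>:<index>" -> (op_name, index); rpartition splits on the LAST ":".
    op_name, sep, index = tensor_name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError(f"not a tensor name: {tensor_name!r}")
    return op_name, int(index)

print(split_tensor_name("Variable_23:0"))  # ('Variable_23', 0)
print(split_tensor_name("foo/bar:2"))      # ('foo/bar', 2)
```

For example, the op behind "Variable_23:0" is "Variable_23", which is the form graph.get_operation_by_name() expects, while graph.get_tensor_by_name() expects the full "Variable_23:0".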
The easiest way to get a variable by name is to search for it in the tf.global_variables() collection:
var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]
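The list comprehension followed by [0] raises a bare IndexError when nothing matches (for instance after a typo such as "Variable23:0"). A hedged alternative is a small helper that returns the first match and otherwise raises a descriptive error. find_by_name is a hypothetical name; it only assumes each item has a .name attribute, as tf.Variable does, so in TF1 you would call it as find_by_name(tf.global_variables(), "Variable_23:0"). The FakeVar stub stands in for real variables so the sketch runs without TensorFlow.

```python
from collections import namedtuple

def find_by_name(variables, name):
    # Return the first item whose .name equals `name` exactly.
    for v in variables:
        if v.name == name:
            return v
    raise LookupError(f"no variable named {name!r}")

# Stand-in for tf.Variable objects so the sketch runs without TensorFlow.
FakeVar = namedtuple("FakeVar", ["name"])
variables = [FakeVar("Variable_22:0"), FakeVar("Variable_23:0")]
print(find_by_name(variables, "Variable_23:0"))  # FakeVar(name='Variable_23:0')
```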
This works well for ad hoc reuse of existing variables. A more structured approach, for when you want to share variables between multiple parts of a model, is covered in the Sharing Variables tutorial.