
Tensorflow variable scope: reuse if variable exists

I want a piece of code that creates a variable within a scope if it doesn't exist, and access the variable if it already exists. I need it to be the same code since it will be called multiple times.

However, Tensorflow needs me to specify whether I want to create or reuse the variable, like this:

with tf.variable_scope("foo"):  # create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True):  # reuse the second time
    v = tf.get_variable("v", [1])

How can I get it to figure out whether to create or reuse it automatically? I.e., I want the above two blocks of code to be the same and have the program run.

asked Jul 23 '16 by holdenlee


People also ask

What is Variable_scope in Tensorflow?

Variable scope allows you to create new variables and to share already created ones, while providing checks against accidental creation or sharing. For details, see the Variable Scope How To; here we present only a few basic examples. Variable scope works as expected only when eager execution is disabled.

What is TF Get_variable?

The function tf.get_variable() returns the existing variable with the same name if it exists, and creates a variable with the specified shape and initializer if it does not.
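To make this get-or-create behavior concrete, here is a minimal pure-Python sketch of the same semantics. This is not TensorFlow code: the `VariableStore` class and its names are hypothetical stand-ins, with a dict playing the role of the graph's variable collection and a plain list standing in for a tensor.

```python
class VariableStore:
    """Hypothetical sketch of tf.get_variable's get-or-create semantics."""

    def __init__(self):
        self._vars = {}  # maps "scope/name" -> list (stand-in for a tensor)

    def get_variable(self, scope, name, shape=None):
        full_name = f"{scope}/{name}"
        if full_name in self._vars:
            # Variable already exists: return the existing one (reuse).
            return self._vars[full_name]
        if shape is None:
            # Mirrors TF: creating a new variable requires a declared shape.
            raise ValueError(f"Shape of a new variable ({full_name}) must be declared")
        # Create the variable, zero-initialized for simplicity.
        self._vars[full_name] = [0.0] * shape[0]
        return self._vars[full_name]

store = VariableStore()
v1 = store.get_variable("foo", "v", [1])  # creates "foo/v"
v2 = store.get_variable("foo", "v")       # reuses the same object
assert v1 is v2
```

The real tf.get_variable additionally checks dtypes, initializers, and the scope's reuse flag, but the lookup-then-create logic above is the core idea.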

What is TF Name_scope?

tf.name_scope(name, default_name=None, values=None): this context manager validates that the given values are from the same graph, makes that graph the default graph, and pushes a name scope in that graph (see tf.Graph.name_scope for more details).


2 Answers

get_variable() raises a ValueError when creating a new variable whose shape is not declared, or when reuse is violated during variable creation. You can exploit this to attempt creation first and fall back to reusing:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2

Note that the following also works:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2

UPDATE: the newer API now supports automatic reuse via tf.AUTO_REUSE:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v
answered Sep 22 '22 by rvinas


Although the "try...except" clause works, I think a more elegant and maintainable way is to separate the variable initialization process from the "reuse" process.

def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variables()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v

Since we usually need to initialize a variable only once but reuse/share it many times, separating the two processes makes the code cleaner. This way, we also avoid going through the "try" clause every time just to check whether the variable has already been created.
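The same two-phase pattern can be sketched without TensorFlow (all names here are hypothetical; a module-level dict stands in for the graph's variable store): initialization fails if the variable already exists, and lookup fails if it was never created, mirroring what reuse=True enforces.

```python
_store = {}  # stand-in for the graph's variable collection

def initialize_variable(scope_name, var_name, shape):
    """Create the variable exactly once; error on double initialization."""
    full_name = f"{scope_name}/{var_name}"
    if full_name in _store:
        raise ValueError(f"Variable {full_name} already exists")
    _store[full_name] = [0.0] * shape[0]  # zero-initialized stand-in tensor

def get_scope_variable(scope_name, var_name):
    """Look up an existing variable; error if missing (like reuse=True)."""
    full_name = f"{scope_name}/{var_name}"
    if full_name not in _store:
        raise ValueError(f"Variable {full_name} does not exist")
    return _store[full_name]

initialize_variable("foo", "v", [1])  # called once, at setup time
v1 = get_scope_variable("foo", "v")   # called as many times as needed
v2 = get_scope_variable("foo", "v")
assert v1 is v2
```

Keeping creation and lookup as separate entry points makes it an explicit error to initialize twice or to read before initializing, instead of silently doing the wrong thing.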

answered Sep 25 '22 by Zhongyu Kuang