
Cloning a network with tf.contrib.graph_editor

Tags: tensorflow

I have a function like this that builds a network.

def build_network(inputs):
  # Some arbitrary set of variables and ops here. For example...
  out = tf.contrib.layers.fully_connected(inputs, 123)
  (...)
  return out

Then I use it to build a network like this.

inputs = tf.placeholder(...)
outputs = build_network(inputs)

If I wanted to build more networks with identical structure but independent variables, I'd just have to call build_network again under some other variable scope and, optionally, with other inputs.
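For example, something like this (the scope name and placeholder shape are arbitrary, just for illustration):

with tf.variable_scope('network_2'):
  inputs_2 = tf.placeholder(tf.float32, shape=[None, 64])  # any compatible inputs
  outputs_2 = build_network(inputs_2)  # same structure, fresh variables under 'network_2'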

My question is: how can I do this if this build_network is no longer available, but the inputs and outputs of the original network are? In other words: how can I clone the entire subgraph from outputs all the way to inputs into another variable scope with its own independent set of variables but identical structure?

My understanding is that tf.contrib.graph_editor in general, and graph_editor.copy in particular, are precisely the tools I need for this kind of thing. However, I couldn't find any good example of their use. Any suggestions?

asked Aug 26 '17 by user3176103

2 Answers

Answering my own question: I have found what looks like a way to copy the subgraph.

from tensorflow.contrib import graph_editor as ge

# As in the example above, but treating the inputs as a list of placeholders.
inputs = [tf.placeholder(...), ...]
outputs = build_network(inputs)

# Build a subgraph view spanning every op between the inputs and the outputs.
sgv = ge.make_view(ge.get_within_boundary_ops(
    tf.get_default_graph(),
    [t.op for t in outputs],
    [t.op for t in inputs]))

# This could be any new inputs. In this example I build new identical placeholders.
new_inputs = {p: tf.placeholder(dtype=p.dtype, shape=p.shape) for p in inputs}

# Copy the subgraph into the 'copy' scope, rerouting its inputs to the new placeholders.
new_sgv, info = ge.copy_with_input_replacements(sgv, new_inputs, dst_scope='copy')

# Retrieve the tensors in the copy that correspond to the original inputs and outputs.
new_inputs = [info.transformed(t) for t in inputs]
new_outputs = [info.transformed(t) for t in outputs]

However, now I'm facing a new problem when trying to use the network copy. The new variables in the copy are not initialized, and trying to run tf.global_variables_initializer() does not help.

The reason is that, since the tf.Variable objects for these were never built, they are not part of the GraphKeys.GLOBAL_VARIABLES collection. I can easily find the ops corresponding to these variables and their mapping between original and copy, but I cannot build a tf.Variable from that.

I found a bit of a hacky workaround to do the initialization, but it only works for variables that are in a collection.

# Find each original variable that lives in the copied subgraph and collect
# the copied version of its initializer op.
init_ops = []
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
  if v.op in sgv.ops:
    init_ops.append(info.transformed(v.initializer))

...

session.run([tf.global_variables_initializer()] + init_ops)

Is there any better way to do this? Ideally one that allows creating tf.Variables for the copied vars to add them to the global vars collection. Or, if that weren't possible, at least a reliable way to get the initializer ops without having to find the tf.Variable objects of the original network.

answered by user3176103 (the asker)

Note: this answer and the OP's answer are complementary to each other. Read the OP's answer first.

I've spent 4 hours today on this issue. This is one of the places where the ugliness of TensorFlow unfolds (and that's why you should use PyTorch if graph manipulation is your thing).


The crucial point here is: a tf.Variable is NOT a graph element (more on it here) but a wrapper around 3 ops: the Assign op, the Read op, and the VariableV2 op, which is essentially a ref tensor (more on it here). So it is something you need to create explicitly at the TensorFlow framework level, not something that lives in the graph itself.

If we look closely at the graph_editor's code, especially the transform module, we can see that it operates only on the tf.Graph and does not touch anything from the TensorFlow framework layer. So the graph_editor.copy (and similar) methods do not touch tf.Variable objects at all; they only copy the tensors and ops that are the building blocks of a tf.Variable.
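
One quick way to see this in practice (a small sketch, assuming the sgv from the OP's answer is still in scope; the scope name 'copy_check' is made up for illustration):

from tensorflow.contrib import graph_editor as ge

n_vars_before = len(tf.global_variables())
_, copy_info = ge.copy(sgv, dst_scope='copy_check')   # copies ops/tensors only
print(len(tf.global_variables()) == n_vars_before)    # True: no new tf.Variable was registered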

Okay, then how do we solve this problem?

Suppose you have the following variable:

var = tf.trainable_variables()[0]
print(var.to_proto())
# variable_name: "dense_1/kernel:0"
# initializer_name: "dense_1/kernel/Assign"
# snapshot_name: "dense_1/kernel/read:0"
# initial_value_name: "dense_1/random_uniform:0"
# trainable: true

You know that after graph_editor.copy(...), your dense_1 name scope is now dense_1b. Then all you need to do is use info.transformed(...) to get the corresponding ops and tensors, and do the following:

from tensorflow.core.framework import variable_pb2

var_def = variable_pb2.VariableDef()
var_def.variable_name = 'dense_1b/kernel:0'
var_def.initializer_name = "dense_1b/kernel/Assign"
var_def.snapshot_name = "dense_1b/kernel/read:0"
var_def.initial_value_name = "dense_1/random_uniform:0"
var_def.trainable = True

Now, I want to emphasize the following part of the tf.Variable documentation:

variable_def: ... recreates the Variable object with its contents, referencing the variable's nodes in the graph, which must already exist. The graph is not changed.

So, the tf.Variable constructor allows us to create a Variable wrapper on top of existing graph elements. That's exactly what we need:

cloned_var = tf.Variable(variable_def=var_def)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, cloned_var)

Solved!


I kept this answer as simple and specific as possible to show the underlying mechanics of tf.Variable. You can now easily implement the code for the more general case and create the new variables automatically.
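
For instance, a sketch of that more general version could look like the following. It assumes the sgv and info from the OP's answer are in scope and that the copied subgraph includes the variables' initializer ops (as in the OP's workaround); clone_variable_wrappers is just a hypothetical helper name.

from tensorflow.core.framework import variable_pb2

def clone_variable_wrappers(sgv, info):
  # Build a tf.Variable wrapper for every original variable whose ops were copied.
  cloned = []
  trainable_ops = set(v.op for v in tf.trainable_variables())
  for v in tf.global_variables():
    if v.op not in sgv.ops:
      continue  # this variable is not part of the copied subgraph
    var_def = variable_pb2.VariableDef()
    var_def.variable_name = info.transformed(v.op).outputs[0].name
    var_def.initializer_name = info.transformed(v.initializer).name
    var_def.snapshot_name = info.transformed(v.value()).name
    var_def.initial_value_name = v.initial_value.name  # points at the original initial value, as above
    var_def.trainable = v.op in trainable_ops
    new_var = tf.Variable(variable_def=var_def)
    tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, new_var)
    if var_def.trainable:
      tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, new_var)
    cloned.append(new_var)
  return cloned

After that, something like session.run(tf.variables_initializer(clone_variable_wrappers(sgv, info))) initializes the copies without touching the original network.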

PS: I hate TensorFlow!

answered by FalconUA