
Tensorflow: How to replace a node in a calculation graph?

If you have two disjoint graphs, and want to link them, turning this:

x = tf.placeholder('float')
y = f(x)

y = tf.placeholder('float')
z = f(y)

into this:

x = tf.placeholder('float')
y = f(x)
z = g(y)

Is there a way to do that? It seems like it could make construction easier in some cases.

For example, if you have a graph that takes the input image as a tf.placeholder, and you want to optimize the input image, deep-dream style, is there a way to just replace the placeholder with a tf.Variable node? Or do you have to plan for that before building the graph?

asked Nov 17 '15 by mdaoust


3 Answers

TL;DR: If you can define the two computations as Python functions, you should do that. If you can't, there's more advanced functionality in TensorFlow to serialize and import graphs, which allows you to compose graphs from different sources.
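For the simple case the TL;DR describes, no graph surgery is needed at all: composing the Python functions composes the graphs. A minimal sketch (f and g here are hypothetical stand-ins for your computations; it uses the TF1-style API through tf.compat.v1 so it also runs under TF 2.x):

```python
import tensorflow.compat.v1 as tf  # TF1-style API; works under TF 2.x
tf.disable_eager_execution()

def f(x):
    return x * 2.0  # hypothetical first computation

def g(y):
    return y + 1.0  # hypothetical second computation

x = tf.placeholder(tf.float32)
z = g(f(x))  # compose directly; the ops land in one graph

with tf.Session() as sess:
    print(sess.run(z, feed_dict={x: 3.0}))  # g(f(3)) = 3*2 + 1 = 7.0
```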

One way to do this in TensorFlow is to build the disjoint computations as separate tf.Graph objects, then convert them to serialized protocol buffers using Graph.as_graph_def():

with tf.Graph().as_default() as g_1:
  input = tf.placeholder(tf.float32, name="input")
  y = f(input)
  # NOTE: using identity to get a known name for the output tensor.
  output = tf.identity(y, name="output")

gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:  # NOTE: g_2, not g_1
  input = tf.placeholder(tf.float32, name="input")
  z = g(input)
  # NOTE: name g_2's own result, z (not y, which belongs to g_1).
  output = tf.identity(z, name="output")

gdef_2 = g_2.as_graph_def()

Then you could compose gdef_1 and gdef_2 into a third graph, using tf.import_graph_def():

with tf.Graph().as_default() as g_combined:
  x = tf.placeholder(tf.float32)

  # Import gdef_1, which performs f(x).
  # "input:0" and "output:0" are the names of tensors in gdef_1.
  y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                           return_elements=["output:0"])

  # Import gdef_2, which performs g(y)
  z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                           return_elements=["output:0"])
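Put together, the whole round trip looks like this. A self-contained sketch (f and g are hypothetical stand-ins; written against the TF1-style API via tf.compat.v1 so it also runs under TF 2.x):

```python
import tensorflow.compat.v1 as tf  # TF1-style API; works under TF 2.x
tf.disable_eager_execution()

def f(x):
    return x * 2.0  # hypothetical first computation

def g(y):
    return y + 1.0  # hypothetical second computation

# Build and serialize the two disjoint graphs.
with tf.Graph().as_default() as g_1:
    inp = tf.placeholder(tf.float32, name="input")
    tf.identity(f(inp), name="output")
gdef_1 = g_1.as_graph_def()

with tf.Graph().as_default() as g_2:
    inp = tf.placeholder(tf.float32, name="input")
    tf.identity(g(inp), name="output")
gdef_2 = g_2.as_graph_def()

# Splice them into a third graph via input_map, then run it.
with tf.Graph().as_default() as g_combined:
    x = tf.placeholder(tf.float32, name="x")
    y, = tf.import_graph_def(gdef_1, input_map={"input:0": x},
                             return_elements=["output:0"])
    z, = tf.import_graph_def(gdef_2, input_map={"input:0": y},
                             return_elements=["output:0"])

with tf.Session(graph=g_combined) as sess:
    result = sess.run(z, feed_dict={x: 3.0})
    print(result)  # g(f(3)) = 3*2 + 1 = 7.0
```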
answered Nov 15 '22 by mrry


If you want to combine trained models (for example to reuse parts of a pretrained model in a new model), you can use a Saver to save a checkpoint of the first model, then restore that model (entirely or partially) into another model.

For example, say you want to reuse model 1's weights w in model 2, and also convert x from a placeholder to a variable:

with tf.Graph().as_default() as g1:
    x = tf.placeholder('float')
    w = tf.Variable(1., name="w")
    y = x * w
    saver = tf.train.Saver()

with tf.Session(graph=g1) as sess:
    w.initializer.run()
    # train...
    saver.save(sess, "my_model1.ckpt")

with tf.Graph().as_default() as g2:
    x = tf.Variable(2., name="v")
    w = tf.Variable(0., name="w")
    z = x + w
    restorer = tf.train.Saver([w]) # only restore w

with tf.Session(graph=g2) as sess:
    x.initializer.run()  # x now needs to be initialized
    restorer.restore(sess, "my_model1.ckpt") # restores w=1
    print(z.eval())  # prints 3.
answered Nov 15 '22 by MiniQuark


It turns out that tf.train.import_meta_graph passes all additional keyword arguments through to the underlying import_scoped_meta_graph, which accepts an input_map argument and forwards it to its own (internal) call to import_graph_def.

It is not documented, and it took me far too long to find, but it works!
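A self-contained sketch of that trick: save a tiny model, then re-import its MetaGraph while remapping the saved placeholder to a new tensor via input_map. (The model, names, and checkpoint path are illustrative; it uses the TF1-style API via tf.compat.v1 so it also runs under TF 2.x.)

```python
import os
import tempfile
import tensorflow.compat.v1 as tf  # TF1-style API; works under TF 2.x
tf.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save a small model: output = input * w, with w = 3.
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, name="input")
    w = tf.Variable(3.0, name="w")
    tf.identity(x * w, name="output")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(w.initializer)
        saver.save(sess, ckpt)  # also writes the .meta (MetaGraph) file

# Re-import the saved graph, swapping the placeholder for a new tensor.
with tf.Graph().as_default() as g:
    new_x = tf.constant(5.0)  # replaces the original "input" placeholder
    restorer = tf.train.import_meta_graph(ckpt + ".meta",
                                          input_map={"input:0": new_x})
    out = g.get_tensor_by_name("output:0")
    with tf.Session() as sess:
        restorer.restore(sess, ckpt)  # restores w = 3
        restored = sess.run(out)
        print(restored)  # 5 * 3 = 15.0
```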

answered Nov 15 '22 by Jonan Gueorguiev