Connect input and output tensors of two different graphs in TensorFlow

I have 2 ProtoBuf files. I currently load each of them and run a forward pass on each one separately, by calling:

out1=session.run(graph1out, feed_dict={graph1inp:inp1})

followed by

final=session.run(graph2out, feed_dict={graph2inp:out1})

where graph1inp and graph1out are the input and output tensors of graph 1, and graph2inp and graph2out are the same for graph 2.

Now I want to connect graph1out to graph2inp so that I only have to run graph2out while feeding inp1 to graph1inp. In other words, I want to connect the output tensor of the first graph to the input tensor of the second, so that a single run is sufficient to perform inference with both trained ProtoBuf files.

Pranay Mathur asked Mar 17 '17 13:03


2 Answers

Assuming that your Protobuf files contain serialized tf.GraphDef protos, you can use the input_map argument of tf.import_graph_def() to connect the two graphs:

import tensorflow as tf

# Import graph1.
graph1_def = ...  # tf.GraphDef object
out1_name = "..."  # name of the graph1out tensor in graph1_def (a tensor name of the form "<op_name>:0").
graph1out, = tf.import_graph_def(graph1_def, return_elements=[out1_name])

# Import graph2 and connect it to graph1: graph2inp is replaced by graph1out.
graph2_def = ...  # tf.GraphDef object
inp2_name = "..."  # name of the graph2inp tensor in graph2_def.
out2_name = "..."  # name of the graph2out tensor in graph2_def.
graph2out, = tf.import_graph_def(graph2_def, input_map={inp2_name: graph1out},
                                 return_elements=[out2_name])
# graph2out now depends on graph1out, so a single session.run(graph2out, ...)
# that feeds graph1's input evaluates both graphs.
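The rewiring that input_map performs can be sketched without TensorFlow. The following is a hypothetical pure-Python toy, not TensorFlow's implementation: a "graph def" is just a dict from node name to (op, input names), and the made-up helpers import_graph and evaluate mimic importing under a name prefix while remapping one input, so that evaluating graph 2's output only requires feeding graph 1's input.

```python
# Hypothetical toy model of importing a graph with a name prefix and an
# input_map. This is NOT TensorFlow code; it only mimics the rewiring idea.
# A "graph def" is a dict: node name -> (op, [input node names]).


def import_graph(target, graph_def, prefix, input_map=None):
    """Copy graph_def's nodes into target under "prefix/".

    Inputs listed in input_map are rewired to nodes already in target
    instead of to the freshly imported copies.
    """
    input_map = input_map or {}

    def rename(name):
        return f"{prefix}/{name}"

    for name, (op, inputs) in graph_def.items():
        new_inputs = [input_map.get(i, rename(i)) for i in inputs]
        target[rename(name)] = (op, new_inputs)
    return rename  # maps original names to imported names


def evaluate(graph, name, feed):
    """Recursively compute node `name`; placeholders must appear in feed."""
    if name in feed:
        return feed[name]
    op, inputs = graph[name]
    if op == "placeholder":
        raise KeyError(f"placeholder {name!r} was not fed")
    args = [evaluate(graph, i, feed) for i in inputs]
    if op == "double":
        return 2 * args[0]
    if op == "add_one":
        return args[0] + 1
    raise ValueError(f"unknown op {op!r}")


graph1_def = {"input": ("placeholder", []), "output": ("double", ["input"])}
graph2_def = {"input": ("placeholder", []), "output": ("add_one", ["input"])}

combined = {}
g1 = import_graph(combined, graph1_def, "graph1")
# Use graph1's output wherever graph2 consumed its own "input".
g2 = import_graph(combined, graph2_def, "graph2", input_map={"input": g1("output")})

# One evaluation runs both graphs: add_one(double(5)).
print(evaluate(combined, g2("output"), {g1("input"): 5}))  # 11
```

In this toy, graph2/input is still copied into the combined graph, but nothing consumes it anymore.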
mrry answered Sep 19 '22 17:09


The accepted answer does connect the two graphs, but it does not restore the collections or the global and trainable variables. After an exhaustive search I came up with a better solution:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, graph mode)
from tensorflow.python.framework import meta_graph

# Two small stand-in graphs: each maps an 'input' placeholder to an 'output' tensor.
with tf.Graph().as_default() as graph1:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(input, name='output')

with tf.Graph().as_default() as graph2:
    input = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(input, name='output')

# The combined graph is the default graph; x will replace graph1's input.
graph = tf.get_default_graph()
x = tf.placeholder(tf.float32, (None, 20), name='input')

We use tf.train.export_meta_graph, which also exports the CollectionDef, and meta_graph.import_scoped_meta_graph to import it. This is where the connection happens, specifically in the input_map parameter.

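# Export graph1 (including its collections) and import it under 'graph1/',
# feeding x in place of graph1's 'input' placeholder.
meta_graph1 = tf.train.export_meta_graph(graph=graph1)
meta_graph.import_scoped_meta_graph(meta_graph1, input_map={'input': x}, import_scope='graph1')
out1 = graph.get_tensor_by_name('graph1/output:0')

# Import graph2 under 'graph2/', feeding graph1's output in place of its 'input'.
meta_graph2 = tf.train.export_meta_graph(graph=graph2)
meta_graph.import_scoped_meta_graph(meta_graph2, input_map={'input': out1}, import_scope='graph2')
out2 = graph.get_tensor_by_name('graph2/output:0')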

Now the graph is connected, and the global variables are re-mapped as well:

print(tf.global_variables())

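You can also import meta graphs directly from a file: import_scoped_meta_graph accepts either a MetaGraphDef or a path to a file containing one as its first argument.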
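The difference this answer points out, collections being carried over and renamed into the import scope, can be illustrated with another hypothetical pure-Python toy. Here a "meta graph" is only a dict holding a graph def and a collections mapping; import_scoped is a made-up helper, not the TensorFlow function, and the variable nodes w and b are invented for the example.

```python
# Hypothetical toy model of a "meta graph": a graph def plus named
# collections (e.g. the list of variables). NOT TensorFlow code; it only
# sketches why importing collections matters when merging graphs.


def import_scoped(graph, collections, meta, scope, input_map=None):
    """Import meta["graph_def"] under "scope/" and merge its collections."""
    input_map = input_map or {}

    def scoped(name):
        return f"{scope}/{name}"

    for name, (op, inputs) in meta["graph_def"].items():
        graph[scoped(name)] = (op, [input_map.get(i, scoped(i)) for i in inputs])
    # Unlike a plain graph-def import, collection entries are carried over
    # too, renamed into the import scope so they point at the copied nodes.
    for key, names in meta.get("collections", {}).items():
        collections.setdefault(key, []).extend(scoped(n) for n in names)
    return scoped


meta1 = {
    "graph_def": {"input": ("placeholder", []),
                  "w": ("variable", []),
                  "output": ("mul", ["input", "w"])},
    "collections": {"global_variables": ["w"]},
}
meta2 = {
    "graph_def": {"input": ("placeholder", []),
                  "b": ("variable", []),
                  "output": ("add", ["input", "b"])},
    "collections": {"global_variables": ["b"]},
}

graph, collections = {"x": ("placeholder", [])}, {}
g1 = import_scoped(graph, collections, meta1, "graph1", {"input": "x"})
g2 = import_scoped(graph, collections, meta2, "graph2", {"input": g1("output")})

print(collections["global_variables"])  # ['graph1/w', 'graph2/b']
print(graph["graph2/output"])           # ('add', ['graph1/output', 'graph2/b'])
```

Dropping the collections loop corresponds to the plain graph-def import: the nodes would still be connected, but the merged collections would stay empty.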

Talmaj answered Sep 17 '22 17:09