If I attempt to import a saved TensorFlow graph definition with:

    import tensorflow as tf
    from tensorflow.python.platform import gfile

    with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        x, y, y_ = tf.import_graph_def(graph_def,
                                       return_elements=['data/inputs',
                                                        'output/network_activation',
                                                        'data/correct_outputs'],
                                       name='')
the returned values are not Tensors as expected, but something else. For example, instead of getting x as

    Tensor("data/inputs:0", shape=(?, 784), dtype=float32)

I get
    name: "data/inputs_1"
    op: "Placeholder"
    attr {
      key: "dtype"
      value {
        type: DT_FLOAT
      }
    }
    attr {
      key: "shape"
      value {
        shape {
        }
      }
    }
That is, instead of the expected tensor x, I get x.op, the Operation that produces it. This confuses me because the documentation seems to say I should get a Tensor (though there are a bunch of "or"s there that make it hard to understand).
How do I get tf.import_graph_def to return specific Tensors that I can then use (e.g. for feeding the loaded model or running analyses)?
The names 'data/inputs', 'output/network_activation' and 'data/correct_outputs' are actually operation names. To get tf.import_graph_def() to return tf.Tensor objects, you should append an output index to each operation name, which is typically ':0' for single-output operations:
    x, y, y_ = tf.import_graph_def(graph_def,
                                   return_elements=['data/inputs:0',
                                                    'output/network_activation:0',
                                                    'data/correct_outputs:0'],
                                   name='')
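If you keep a list of operation names around (for instance, from a config), a small helper can turn them into tensor names of the form "<op_name>:<output_index>" before passing them as return_elements. This is a minimal sketch; the helper name to_tensor_names is hypothetical and not part of TensorFlow. It leaves names that already carry a numeric output index unchanged and appends ':0' (or another chosen index) to bare operation names.

```python
# Hypothetical helper (NOT a TensorFlow API): convert operation names into
# tensor names "<op_name>:<output_index>" so that
# tf.import_graph_def(..., return_elements=...) yields tf.Tensor objects.
from typing import List


def to_tensor_names(names: List[str], output_index: int = 0) -> List[str]:
    """Append ':<output_index>' to every name that lacks an output index."""
    result = []
    for name in names:
        _, sep, suffix = name.rpartition(":")
        if sep and suffix.isdigit():
            # Already a tensor name such as "data/inputs:0"; keep it as-is.
            result.append(name)
        else:
            # A bare operation name such as "data/inputs".
            result.append(f"{name}:{output_index}")
    return result


if __name__ == "__main__":
    print(to_tensor_names(["data/inputs",
                           "output/network_activation",
                           "data/correct_outputs"]))
```

The result can then be passed directly, e.g. tf.import_graph_def(graph_def, return_elements=to_tensor_names(op_names), name=''), where op_names is your own list of operation names. Use a different output_index only for operations whose wanted output is not the first one.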