My TensorFlow version is 0.11. I want to save a graph after training, or save something else that TensorFlow can load later.
I/ Using "Exporting and Importing a MetaGraph"
I already read this post: Tensorflow: how to save/restore a model?
My Save.py file:
X = tf.placeholder("float", [None, 28, 28, 1], name='X')
Y = tf.placeholder("float", [None, 10], name='Y')
saver = tf.train.Saver()
with tf.Session() as sess:
    # ...run something ...
    final_tensor = tf.nn.softmax(py_x, name='final_result')
    tf.add_to_collection("final_tensor", final_tensor)
    predict_op = tf.argmax(py_x, 1)
    tf.add_to_collection("predict_op", predict_op)
    saver.save(sess, 'my_project')
Then I run load.py:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('my_project.meta')
    new_saver.restore(sess, 'my_project')
    predict_op = tf.get_collection("predict_op")[0]
    for i in range(2):
        test_indices = np.arange(len(teX)) # Get A Test Batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))
but it returns an error:
Traceback (most recent call last):
File "load_05_convolution.py", line 62, in <module>
"p_keep_hidden:0": 1.0})))
File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 717, in run
run_metadata_ptr)
File "/home/khoa/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 894, in _run
% (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
ValueError: Cannot feed value of shape (256, 784) for Tensor u'X:0', which has shape '(?, 28, 28, 1)'
I really don't know why.
If I add final_tensor = tf.get_collection("final_result")[0]
it returns another error:
Traceback (most recent call last):
File "load_05_convolution.py", line 46, in <module>
final_tensor = tf.get_collection("final_result")[0]
IndexError: list index out of range
Is it because tf.add_to_collection can only hold one placeholder?
II/ Using tf.train.write_graph
I added this line to the end of save.py:
tf.train.write_graph(graph, 'folder', 'train.pb')
It created the file 'train.pb' successfully.
My load.py:
with tf.gfile.FastGFile('folder/train.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _ = tf.import_graph_def(graph_def, name='')
with tf.Session() as sess:
    predict_op = sess.graph.get_tensor_by_name('predict_op:0')
    for i in range(2):
        test_indices = np.arange(len(teX)) # Get A Test Batch
        np.random.shuffle(test_indices)
        test_indices = test_indices[0:test_size]
        print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                         sess.run(predict_op, feed_dict={"X:0": teX[test_indices],
                                                         "p_keep_conv:0": 1.0,
                                                         "p_keep_hidden:0": 1.0})))
Then it returns an error:
Traceback (most recent call last):
File "load_05_convolution.py", line 22, in <module>
graph_def.ParseFromString(f.read())
File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/message.py", line 185, in ParseFromString
self.MergeFromString(serialized)
File "/home/khoa/tensorflow/lib/python2.7/site-packages/google/protobuf/internal/python_message.py", line 1085, in MergeFromString
raise message_mod.DecodeError('Unexpected end-group tag.')
google.protobuf.message.DecodeError: Unexpected end-group tag.
Would you mind sharing the standard way, code, or a tutorial to save/load a model? I'm really confused.
Your first solution (using the MetaGraph) almost works, but the error arises because you are feeding a batch of flattened MNIST examples (shape batch_size x 784) to a tf.placeholder() that expects a batch of MNIST examples as a 4-D tensor with shape batch_size x height (= 28) x width (= 28) x channels (= 1). The easiest way to solve this is to reshape your input data. Instead of this statement:
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices],
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))
...try the following statement, which reshapes your input data appropriately:
print(i, np.mean(np.argmax(teY[test_indices], axis=1) ==
                 sess.run(predict_op, feed_dict={
                     "X:0": teX[test_indices].reshape(-1, 28, 28, 1),
                     "p_keep_conv:0": 1.0,
                     "p_keep_hidden:0": 1.0})))