I'm training my model using TensorFlow in C++. Python is used only for constructing the graph. So is there a way to save and restore the graph and its state purely in C++? I know about the Python class tf.train.Saver, but as far as I understand it does not exist in C++.
The tf.train.Saver class currently exists only in Python, but (i) it is built from TensorFlow ops that you can run from C++, and (ii) it exposes the Saver.as_saver_def() method, which returns a SaverDef protocol buffer containing the names of the ops you must run to save or restore a model.
In Python, you can get the names of the save and restore ops as follows:
import tensorflow as tf
saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()
# The name of the tensor you must feed with a filename when saving/restoring.
print(saver_def.filename_tensor_name)
# The name of the target operation you must run when restoring.
print(saver_def.restore_op_name)
# The name of the tensor you must fetch when saving.
print(saver_def.save_tensor_name)
In C++, to restore from a checkpoint, you call Session::Run(), feeding the name of the checkpoint file into saver_def.filename_tensor_name, with saver_def.restore_op_name as the target op. To save another checkpoint, you call Session::Run(), again feeding the name of the checkpoint file into saver_def.filename_tensor_name, and fetching the value of saver_def.save_tensor_name.
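The only difference between the two calls is which argument of Session::Run() each SaverDef name goes into. As a TensorFlow-free sketch of that bookkeeping, the hypothetical SaverNames and RunPlan types below record the feeds, fetches and targets for a save and for a restore. In real code the strings would come from the SaverDef accessors, and the three lists map onto the inputs, output_tensor_names and target_node_names arguments of Session::Run().

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the three SaverDef fields discussed above.
// In real code these strings come from the tensorflow::SaverDef accessors.
struct SaverNames {
  std::string filename_tensor_name;
  std::string save_tensor_name;
  std::string restore_op_name;
};

// Names to pass to Session::Run(): which tensor receives the checkpoint path
// (feeds), which tensors to fetch, and which ops to run as targets.
struct RunPlan {
  std::vector<std::pair<std::string, std::string>> feeds;  // (tensor name, checkpoint path)
  std::vector<std::string> fetches;                         // output tensor names
  std::vector<std::string> targets;                         // target node names
};

// Saving: feed the path into filename_tensor_name and fetch save_tensor_name.
RunPlan PlanSave(const SaverNames& names, const std::string& path) {
  RunPlan plan;
  plan.feeds.emplace_back(names.filename_tensor_name, path);
  plan.fetches.push_back(names.save_tensor_name);
  return plan;
}

// Restoring: feed the same path and run restore_op_name as a target; nothing is fetched.
RunPlan PlanRestore(const SaverNames& names, const std::string& path) {
  RunPlan plan;
  plan.feeds.emplace_back(names.filename_tensor_name, path);
  plan.targets.push_back(names.restore_op_name);
  return plan;
}
```

For example, PlanSave(names, "/tmp/model.ckpt") yields one feed, one fetch (save_tensor_name) and no targets, while PlanRestore yields the same feed, no fetches and a single target (restore_op_name). Any concrete names you plug in are only illustrative; read the real ones from your SaverDef.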
Recent TensorFlow versions include some helper functions to do the same in C++ without Python. They are generated from the protocol buffer definition, and their header ships in the pip package (${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h).
// graph_def is assumed to be a tensorflow::MetaGraphDef (e.g. read from the .meta file);
// saver_def() is a field of MetaGraphDef, not of GraphDef.
// sess is a tensorflow::Session into which that graph has already been created.
typedef std::vector<std::pair<std::string, tensorflow::Tensor>> tensor_dict;
// Both save and restore feed the checkpoint path into filename_tensor_name.
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{graph_def.saver_def().filename_tensor_name(), checkpointPathTensor}};
// save: fetch save_tensor_name (a tensor name, so it goes in output_tensor_names)
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status status = sess->Run(feed_dict, {graph_def.saver_def().save_tensor_name()}, {}, &outputs);
// restore: run restore_op_name as a target op; nothing needs to be fetched
status = sess->Run(feed_dict, {}, {graph_def.saver_def().restore_op_name()}, nullptr);
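Session::Run() distinguishes tensor names (used for feeds and fetches, typically of the form node:index) from node names (used for targets). restore_op_name is an op name and can go straight into the targets, whereas filename_tensor_name and save_tensor_name are tensor names. If you would rather run the save step as a target instead of fetching its tensor, you need the node part of save_tensor_name. A minimal sketch of that split, using a hypothetical SplitTensorName helper (not a TensorFlow API):

```cpp
#include <cassert>
#include <string>
#include <utility>

// Illustrative helper (not a TensorFlow API): splits a tensor name of the form
// "node:index" into the node name and output index. A name with no all-digit
// ":index" suffix is returned unchanged with index 0.
std::pair<std::string, int> SplitTensorName(const std::string& name) {
  const std::string::size_type colon = name.rfind(':');
  if (colon == std::string::npos || colon + 1 == name.size()) {
    return {name, 0};
  }
  for (std::string::size_type i = colon + 1; i < name.size(); ++i) {
    if (name[i] < '0' || name[i] > '9') {
      return {name, 0};  // suffix is not a number, so treat the whole string as a node name
    }
  }
  return {name.substr(0, colon), std::stoi(name.substr(colon + 1))};
}
```

With this, SplitTensorName(saver_def.save_tensor_name()).first gives a node name suitable for target_node_names, while an op name such as restore_op_name passes through unchanged.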
This is based on the undocumented Python way of restoring a model:
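import tensorflow as tf
def restore(sess, metaGraph, fn):
    # metaGraph is the tf.train.Saver returned by tf.train.import_meta_graph(...)
    restore_op_name = metaGraph.as_saver_def().restore_op_name  # u'save/restore_all'
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
    filename_tensor_name = metaGraph.as_saver_def().filename_tensor_name  # u'save/Const'
    # Feed the checkpoint path by tensor name and run the restore op.
    sess.run(restore_op, {filename_tensor_name: fn})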
For a working and complete version see here.