 

How to save and restore a TensorFlow graph and its state in C++?

Tags:

c++

tensorflow

I'm training my model using TensorFlow in C++. Python is used only for constructing the graph. Is there a way to save and restore the graph and its state purely in C++? I know about the Python class tf.train.Saver, but as far as I understand, it does not exist in C++.

Anton Pechenko asked Dec 14 '22 06:12


2 Answers

The tf.train.Saver class currently exists only in Python, but (i) it is built from TensorFlow ops that you can run from C++, and (ii) it exposes the Saver.as_saver_def() method that lets you get a SaverDef protocol buffer with the names of ops that you must run to save or restore a model.

In Python, you can get the names of the save and restore ops as follows:

import tensorflow as tf  # TensorFlow 1.x API (tf.train.Saver)

saver = tf.train.Saver(...)
saver_def = saver.as_saver_def()

# The name of the tensor you must feed with a filename when saving/restoring.
print(saver_def.filename_tensor_name)

# The name of the target operation you must run when restoring.
print(saver_def.restore_op_name)

# The name of the tensor you must fetch when saving.
print(saver_def.save_tensor_name)
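The three names printed above are of two different kinds: filename_tensor_name and save_tensor_name are tensor names of the form op_name:output_index, while restore_op_name is a bare op name. Session::Run takes tensor names for feeds and fetches but node names for targets, so it helps to keep the two apart. Below is a minimal standalone sketch; SplitTensorName is a hypothetical helper, not a TensorFlow function, and the sample names such as "save/Const:0" are assumed typical defaults of a tf.train.Saver, so check your own SaverDef.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <utility>

// Hypothetical helper: split a tensor name such as "save/Const:0" into its
// op name and output index. A name without a ":<digits>" suffix (for example
// an op name like "save/restore_all") is returned unchanged with index 0.
std::pair<std::string, int> SplitTensorName(const std::string& name) {
    const std::size_t colon = name.rfind(':');
    if (colon == std::string::npos || colon + 1 == name.size()) {
        return {name, 0};
    }
    const std::string suffix = name.substr(colon + 1);
    for (char c : suffix) {
        if (c < '0' || c > '9') {
            return {name, 0};  // suffix is not a numeric output index
        }
    }
    return {name.substr(0, colon), std::stoi(suffix)};
}
```

For example, SplitTensorName("save/Const:0") yields the op "save/Const" and output 0, which is what you would look up if you needed the node behind the filename tensor.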

In C++, to restore from a checkpoint, call Session::Run(), feeding the name of the checkpoint file into saver_def.filename_tensor_name and running saver_def.restore_op_name as a target op. To save another checkpoint, call Session::Run() again, feeding the name of the checkpoint file into saver_def.filename_tensor_name and fetching the value of saver_def.save_tensor_name.
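The protocol in the previous paragraph comes down to which Session::Run argument each SaverDef name goes into. A minimal, TensorFlow-free sketch of that bookkeeping follows; SaverNames and RunArgs are hypothetical stand-ins for tensorflow::SaverDef and for the feed/fetch/target lists of Session::Run, and the checkpoint path is kept as a plain string instead of a tensorflow::Tensor.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for the three tensorflow::SaverDef fields used here.
struct SaverNames {
    std::string filename_tensor_name;  // tensor to feed with the checkpoint path
    std::string save_tensor_name;      // tensor to fetch when saving
    std::string restore_op_name;       // op to run as a target when restoring
};

// Hypothetical mirror of Session::Run's argument lists:
// feeds (tensor name -> value), fetches (tensor names), targets (node names).
struct RunArgs {
    std::vector<std::pair<std::string, std::string>> feeds;
    std::vector<std::string> fetches;
    std::vector<std::string> targets;
};

// Save: feed the path, fetch save_tensor_name, no targets.
RunArgs SaveArgs(const SaverNames& s, const std::string& path) {
    RunArgs args;
    args.feeds.emplace_back(s.filename_tensor_name, path);
    args.fetches.push_back(s.save_tensor_name);
    return args;
}

// Restore: feed the path, run restore_op_name as a target, no fetches.
RunArgs RestoreArgs(const SaverNames& s, const std::string& path) {
    RunArgs args;
    args.feeds.emplace_back(s.filename_tensor_name, path);
    args.targets.push_back(s.restore_op_name);
    return args;
}
```

With the real API, feeds would carry a scalar DT_STRING tensor holding the path, and the three lists correspond to Session::Run's inputs, output_tensor_names and target_node_names parameters.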

mrry answered Dec 29 '22 00:12


Recent TensorFlow versions ship the generated C++ protobuf classes needed to do the same in C++ without Python. They are generated from the protocol buffer definitions and included in the pip package (${HOME}/.local/lib/python2.7/site-packages/tensorflow/include/tensorflow/core/protobuf/saver.pb.h).

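// Assumes meta_graph_def is the tensorflow::MetaGraphDef exported from Python (the .meta file)
// and sess is a tensorflow::Session* whose graph was created from meta_graph_def.graph_def().
// Note: GraphDef has no saver_def() field; MetaGraphDef does.
using tensor_dict = std::vector<std::pair<std::string, tensorflow::Tensor>>;
const tensorflow::SaverDef& saver_def = meta_graph_def.saver_def();

// Feed the checkpoint path as a scalar string tensor (used by both save and restore).
// Newer TensorFlow versions use tensorflow::tstring instead of std::string here.
tensorflow::Tensor checkpointPathTensor(tensorflow::DT_STRING, tensorflow::TensorShape());
checkpointPathTensor.scalar<std::string>()() = "some/path";
tensor_dict feed_dict = {{saver_def.filename_tensor_name(), checkpointPathTensor}};

// save: fetch save_tensor_name (a tensor name, e.g. "save/control_dependency:0")
std::vector<tensorflow::Tensor> outputs;
tensorflow::Status status = sess->Run(feed_dict, {saver_def.save_tensor_name()}, {}, &outputs);

// restore: run restore_op_name (an op name, e.g. "save/restore_all") as a target
status = sess->Run(feed_dict, {}, {saver_def.restore_op_name()}, nullptr);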

This is based on the undocumented Python way of restoring a model:

import tensorflow as tf  # TensorFlow 1.x API

def restore(sess, saver, fn):
    # saver: the tf.train.Saver returned by tf.train.import_meta_graph(...)
    saver_def = saver.as_saver_def()
    restore_op_name = saver_def.restore_op_name  # e.g. u'save/restore_all'
    restore_op = tf.get_default_graph().get_operation_by_name(restore_op_name)
    filename_tensor_name = saver_def.filename_tensor_name  # e.g. u'save/Const:0'
    sess.run(restore_op, {filename_tensor_name: fn})

For a working and complete version see here.

Patwie answered Dec 29 '22 01:12