TensorFlow: Different ways to export and run a graph in C++

To import your trained network into C++, you first need to export it. After a lot of searching that turned up almost no information, it became clear that we should use freeze_graph() to do this.

Thankfully, the new 0.7 version of TensorFlow added documentation for it.

After looking into the documentation, I found that there are a few similar methods. What is the difference between freeze_graph() and tf.train.export_meta_graph? The latter has similar parameters, and it seems it can also be used for importing models into C++. (My guess is that the difference is that the file output by this method can only be used with import_graph_def() — or is it something else?)
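For concreteness, here is roughly how I understand the two calls are invoked (a sketch only; exact signatures vary across TF versions, and the 'output' node name and the models/ paths are placeholders of mine):

from tensorflow.python.tools import freeze_graph
import tensorflow as tf

# freeze_graph(): bakes checkpointed weights into the GraphDef as constants
freeze_graph.freeze_graph(
    input_graph='models/graph.pbtxt',    # GraphDef written by write_graph()
    input_saver='',
    input_binary=False,
    input_checkpoint='models/my-model',  # checkpoint prefix holding the weights
    output_node_names='output',          # placeholder node name
    restore_op_name='save/restore_all',
    filename_tensor_name='save/Const:0',
    output_graph='models/frozen.pb',     # self-contained result
    clear_devices=True,
    initializer_nodes='')

# export_meta_graph(): saves the graph structure (plus saver info), not the weights
tf.train.export_meta_graph(filename='models/my-model.meta')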

Also, one question about how to use write_graph(): in the documentation the graph_def is given by sess.graph_def, but in the examples in freeze_graph() it is sess.graph.as_graph_def(). What is the difference between these two?
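Side by side, the two spellings I am comparing (assuming a live session sess; as far as I can tell, both hand a GraphDef to write_graph()):

import tensorflow as tf

with tf.Session() as sess:
    # via the session's convenience property
    tf.train.write_graph(sess.graph_def, 'models/', 'graph.pbtxt')
    # via the underlying graph object directly
    tf.train.write_graph(sess.graph.as_graph_def(), 'models/', 'graph.pbtxt')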

This question is related to this issue.

Thank you!

asked Feb 19 '16 by Hamed MP

1 Answer

Here's my solution utilizing the V2 checkpoints introduced in TF 0.12.

There's no need to convert all variables to constants or freeze the graph.

Just for clarity, a V2 checkpoint looks like this in my models directory:

checkpoint                     # some information on the names of the files in the checkpoint
my-model.data-00000-of-00001   # the saved weights
my-model.index                 # probably the layout of the data in the previous file
my-model.meta                  # protobuf of the graph (nodes and topology info)

Python part (saving)

with tf.Session() as sess:
    tf.train.Saver(tf.trainable_variables()).save(sess, 'models/my-model')

If you create the Saver with tf.trainable_variables(), you can save yourself some headache and storage space. But some more complicated models may need all their data saved; in that case, remove this argument to Saver. Just make sure you create the Saver after your graph is built. It is also very wise to give all variables/layers unique names, otherwise you can run into various problems.
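A hedged sketch of what I mean by unique names, on a toy graph of my own (the 'input'/'output' names and the shapes are made up; your graph will differ):

import tensorflow as tf

# Explicitly named placeholder and variables; these names are what
# you will look up again at inference time.
x = tf.placeholder(tf.float32, shape=[None, 784], name='input')
w = tf.Variable(tf.truncated_normal([784, 10]), name='weights')
b = tf.Variable(tf.zeros([10]), name='bias')
y = tf.nn.softmax(tf.matmul(x, w) + b, name='output')

# Saver over trainable variables only, created after the graph is built
saver = tf.train.Saver(tf.trainable_variables())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'models/my-model')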

Python part (inference)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))
    outputTensors = sess.run(outputOps, feed_dict=feedDict)
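The snippet above assumes you already have outputOps and feedDict. If you only know the names from the saving script, you can recover the tensors like this (a sketch; 'input:0'/'output:0' are placeholder names that must match whatever you named your ops):

import numpy as np
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('models/my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('models/'))

    graph = tf.get_default_graph()
    # ':0' selects the first output of the named op
    inputTensor = graph.get_tensor_by_name('input:0')
    outputOps = [graph.get_tensor_by_name('output:0')]
    feedDict = {inputTensor: np.zeros((1, 784), dtype=np.float32)}

    outputTensors = sess.run(outputOps, feed_dict=feedDict)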

C++ part (inference)

Note that checkpointPath isn't a path to any of the existing files, just their common prefix. If you mistakenly put the path to the .index file there, TF won't tell you that was wrong, but it will die during inference due to uninitialized variables.

#include <tensorflow/core/public/session.h>
#include <tensorflow/core/protobuf/meta_graph.pb.h>

using namespace std;
using namespace tensorflow;

...
// set up your input paths
const string pathToGraph = "models/my-model.meta";
const string checkpointPath = "models/my-model";
...

auto session = NewSession(SessionOptions());
if (session == nullptr) {
    throw runtime_error("Could not create Tensorflow session.");
}

Status status;

// Read in the protobuf graph we exported
MetaGraphDef graph_def;
status = ReadBinaryProto(Env::Default(), pathToGraph, &graph_def);
if (!status.ok()) {
    throw runtime_error("Error reading graph definition from " + pathToGraph + ": " + status.ToString());
}

// Add the graph to the session
status = session->Create(graph_def.graph_def());
if (!status.ok()) {
    throw runtime_error("Error creating graph: " + status.ToString());
}

// Read the weights from the saved checkpoint
Tensor checkpointPathTensor(DT_STRING, TensorShape());
checkpointPathTensor.scalar<std::string>()() = checkpointPath;
status = session->Run(
        {{ graph_def.saver_def().filename_tensor_name(), checkpointPathTensor },},
        {},
        {graph_def.saver_def().restore_op_name()},
        nullptr);
if (!status.ok()) {
    throw runtime_error("Error loading checkpoint from " + checkpointPath + ": " + status.ToString());
}

// and run the inference to your liking
auto feedDict = ...
auto outputOps = ...
std::vector<tensorflow::Tensor> outputTensors;
status = session->Run(feedDict, outputOps, {}, &outputTensors);
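If you're unsure what string to hand over as checkpointPath, a quick sanity check from Python (assuming the same models/ directory as above) prints the prefix the C++ side expects:

import tensorflow as tf

# latest_checkpoint() returns the common prefix of the checkpoint
# files, not a path to any single file
print(tf.train.latest_checkpoint('models/'))  # e.g. models/my-model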
answered Oct 01 '22 by Martin Pecka