
Using Tensorflow checkpoint to restore model in C++

I've trained a network that I implemented in TensorFlow using Python. At the end of training, I saved the model with tf.train.Saver(). Now I would like to use C++ to make predictions with this pre-trained network.

How can I do that? Is there a way to convert the checkpoint so that I can use it with tiny-dnn or the TensorFlow C++ API?

Any ideas are welcome :) Thank you!

A. Piro asked Dec 19 '22 06:12



1 Answer

You should probably export the model in the SavedModel format, which encapsulates both the computational graph and the saved variables (tf.train.Saver only saves the variables, so you'd have to save the graph anyway).

You can then load the saved model in C++ using LoadSavedModel.

The exact invocation depends on what the inputs and outputs of your model are, but the Python code would look something like this:

import tensorflow as tf  # TensorFlow 1.x API

# `sess` is the session holding your trained variables; `input_tensor` and
# `output_tensor` are your model's input and output tensors.
# You'd adjust the arguments here according to your model
signature = tf.saved_model.signature_def_utils.predict_signature_def(
    inputs={'image': input_tensor}, outputs={'scores': output_tensor})

builder = tf.saved_model.builder.SavedModelBuilder('/tmp/my_saved_model')

builder.add_meta_graph_and_variables(
    sess=sess,
    tags=[tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            signature
    })

builder.save()

And then in C++ you'd do something like this:

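#include <iostream>
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/cc/saved_model/tag_constants.h"

// The statements below go inside your own function; default options are used.
tensorflow::SessionOptions session_options;
tensorflow::RunOptions run_options;
tensorflow::SavedModelBundle model;
auto status = tensorflow::LoadSavedModel(session_options, run_options, "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &model);
if (!status.ok()) {
   std::cerr << "Failed: " << status << std::endl;
   return;
}
// At this point you can use model.session, e.g. model.session->Run(...)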

(Note that using the SavedModel format will also allow you to serve the model with TensorFlow Serving, if that makes sense for your application.)

Hope that helps.

ash answered Dec 21 '22 11:12
