How to restore a TensorFlow model from a .pb file in Python?

I have a TensorFlow .pb file that I would like to load into Python, restore the graph from, and use to get predictions. I am doing this to check whether the .pb file makes predictions similar to those of the model saved the normal way with Saver.save().

My basic problem is that I am getting very different prediction values when I run the above-mentioned .pb file on Android.

My .pb file creation code:

    import tensorflow as tf

    # `session` is the tf.Session that holds the trained variables.
    frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
    )
    with open('frozen_model.pb', 'wb') as f:
        f.write(frozen_graph.SerializeToString())

So I have two major concerns:

  1. How can I load the above-mentioned .pb file into a Python TensorFlow model?
  2. Why am I getting completely different prediction values in Python and on Android?
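To make "very different" measurable, a small hypothetical helper can compare the softmax outputs from the two environments on the same inputs. This sketch assumes both sets of predictions have been exported as plain lists of per-class probabilities; it reports the largest element-wise difference and how often the top-1 class agrees. The function name and the tolerance `atol=1e-4` are illustrative assumptions, not part of any library.

```python
def compare_predictions(reference, candidate, atol=1e-4):
    """Compare two batches of per-class probability vectors for the same inputs.

    `reference` could be the Saver-restored model's output in Python and
    `candidate` the frozen graph's output on Android (both lists of lists).
    """
    if not reference or len(reference) != len(candidate):
        raise ValueError("batches must be non-empty and the same size")
    max_abs_diff = 0.0
    top1_matches = 0
    for ref_row, cand_row in zip(reference, candidate):
        if len(ref_row) != len(cand_row):
            raise ValueError("probability vectors differ in length")
        # Track the largest element-wise gap across the whole batch.
        row_diff = max(abs(a - b) for a, b in zip(ref_row, cand_row))
        max_abs_diff = max(max_abs_diff, row_diff)
        # Count rows where both sides predict the same (argmax) class.
        if ref_row.index(max(ref_row)) == cand_row.index(max(cand_row)):
            top1_matches += 1
    return {
        "max_abs_diff": max_abs_diff,
        "top1_agreement": top1_matches / len(reference),
        "within_tolerance": max_abs_diff <= atol,
    }
```

For example, `compare_predictions([[0.1, 0.9], [0.7, 0.3]], [[0.1, 0.9], [0.6, 0.4]])` reports a maximum difference of about 0.1 with full top-1 agreement, which is outside the tolerance. Large gaps on identical inputs usually point to the exported graph or to preprocessing differing between the two sides.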
asked May 31 '18 at 20:05 by vizsatiz



People also ask

How do I restore a saved model in TensorFlow?

Restoring models: the first thing to do when restoring a TensorFlow model is to load the graph structure from the .meta file into the current graph. The current graph can then be explored with tf.get_default_graph().

How do I import a .pb file into Keras?

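If the .pb file is the saved_model.pb inside a SavedModel directory, pass the directory path to Keras: model = tf.keras.models.load_model('./_models/vgg50_finetune'). You can then either keep training the model or use it for prediction. This does not load a standalone frozen graph such as frozen_model.pb.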

What is .PB file in TensorFlow?

The .pb extension denotes the protocol buffer (protobuf) format, which TensorFlow uses to store models. Protobuf is Google's general-purpose format for serializing structured data: it encodes data compactly, which makes it efficient to transport, and it enforces a schema on the data.


1 Answer

The following code will read the model and print out the names of the nodes in the graph.

    import tensorflow as tf  # TensorFlow 1.x API, as used in the question

    GRAPH_PB_PATH = './frozen_model.pb'

    with tf.Session() as sess:
        print("load graph")
        with tf.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
            graph_def = tf.GraphDef()
            # Parse while the file is still open.
            graph_def.ParseFromString(f.read())
        # Import into the session's graph; name='' keeps the original node names.
        with sess.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        # Print every node name so you can find the input and output tensors.
        names = [n.name for n in graph_def.node]
        print(names)
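To go one step further and actually get predictions from the frozen file, a minimal sketch, assuming TensorFlow 1.13+ or 2.x (for `tf.compat.v1`), is to import the GraphDef into its own tf.Graph and run the output tensor. The output name `outputLayer/Softmax` comes from the question; the input name `inputLayer` is a hypothetical placeholder, so substitute the real name from the printed node list. The helper names are also illustrative.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TF 2.x


def load_frozen_graph(pb_path):
    """Parse a frozen GraphDef file and import it into a fresh tf.Graph."""
    graph_def = tf1.GraphDef()
    with tf1.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    graph = tf1.Graph()
    with graph.as_default():
        # name="" keeps the original node names (no "import/" prefix).
        tf1.import_graph_def(graph_def, name="")
    return graph


def predict(graph, input_name, output_name, batch):
    """Feed `batch` into the tensor `input_name` and return `output_name`'s value."""
    with tf1.Session(graph=graph) as sess:
        return sess.run(output_name, feed_dict={input_name: batch})


# Example usage ("inputLayer:0" is hypothetical; x_batch is your preprocessed input):
# graph = load_frozen_graph("./frozen_model.pb")
# probs = predict(graph, "inputLayer:0", "outputLayer/Softmax:0", x_batch)
```

Importing into a fresh graph rather than the default one keeps repeated loads from colliding on node names.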

You are not freezing the graph properly, which is why you are getting different results: basically, the weights are not being stored in your model. You can use freeze_graph.py (link) to get a correctly frozen graph.
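One quick way to check whether the weights made it into the file is to look at the op types in the parsed GraphDef. Freezing is meant to replace every variable with a Const node holding its value, so a correctly frozen graph should contain no variable ops. A minimal sketch; the function name is hypothetical, and the set of op types covers the common variable ops (`Variable` and `VariableV2` for ref variables, `VarHandleOp` for resource variables), not necessarily every variant.

```python
# Op types that hold (unfrozen) variables in a TensorFlow GraphDef.
VARIABLE_OPS = {"Variable", "VariableV2", "VarHandleOp"}


def unfrozen_variable_nodes(graph_def):
    """Return the names of nodes that still hold variables instead of constants.

    Works on any object with a `.node` sequence whose items have `.name` and
    `.op` attributes, such as a parsed tf.GraphDef. An empty list means no
    variable ops remain in the graph.
    """
    return [node.name for node in graph_def.node if node.op in VARIABLE_OPS]


# Example usage, with graph_def parsed from frozen_model.pb as in the answer's code:
# leftover = unfrozen_variable_nodes(graph_def)
# if leftover:
#     print("Not fully frozen; variables still present:", leftover)
```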

answered Oct 09 '22 at 11:10 by Pranjal Sahu
