 

Freezing graph to pb in Tensorflow2

We deploy a lot of our models from TF1 by saving them through graph freezing:

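import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

# write_graph also requires a file name; "graph.pbtxt" is illustrative
tf.train.write_graph(self.session.graph_def, some_path, "graph.pbtxt")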

# get graph definitions with weights
output_graph_def = tf.graph_util.convert_variables_to_constants(
        self.session,  # The session is used to retrieve the weights
        self.session.graph.as_graph_def(),  # The graph_def is used to retrieve the nodes
        output_nodes,  # The output node names are used to select the useful nodes
)

# optimize graph
if optimize:
    output_graph_def = optimize_for_inference_lib.optimize_for_inference(
            output_graph_def, input_nodes, output_nodes, tf.float32.as_datatype_enum
    )

with open(path, "wb") as f:
    f.write(output_graph_def.SerializeToString())

and then load them with:

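# read the serialized GraphDef back (assumes `path` is the file written above)
graph_def = tf.GraphDef()
with open(path, "rb") as f:
    graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
    with graph.device("/" + args[name].processing_unit):
        tf.import_graph_def(graph_def, name="")
        for key, value in inputs.items():
            self.input[key] = graph.get_tensor_by_name(value + ":0")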

We would like to save TF2 models in a similar way: a single protobuf file that includes both the graph and the weights. How can I achieve this?

I know that there are some methods for saving:

  • keras.experimental.export_saved_model(model, 'path_to_saved_model')

    Which is experimental and creates multiple files :(.

  • model.save('path_to_my_model.h5')

    Which saves in the HDF5 (.h5) format :(.

  • tf.saved_model.save(self.model, "test_x_model")

    Which again saves multiple files :(.
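For reference, here is a minimal sketch of what the last option writes to disk, assuming TF 2.x (the `Scale` module and the temporary directory are hypothetical stand-ins). `tf.saved_model.save` produces a directory with `saved_model.pb` (the graph) plus a `variables/` subdirectory holding the weights, rather than one self-contained file:

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    """Hypothetical toy model: multiplies its input by one trainable weight."""

    def __init__(self):
        super().__init__()
        self.factor = tf.Variable(2.0)

    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * self.factor


export_dir = os.path.join(tempfile.mkdtemp(), "test_x_model")
tf.saved_model.save(Scale(), export_dir)

# Collect every file the export produced, relative to export_dir.
saved_files = sorted(
    os.path.relpath(os.path.join(root, name), export_dir)
    for root, _, names in os.walk(export_dir)
    for name in names
)
print(saved_files)

# The directory as a whole can be loaded back and called.
loaded = tf.saved_model.load(export_dir)
result = loaded(tf.constant([1.0, 3.0])).numpy().tolist()
```

The graph and the weights live in separate files here, which is exactly what the question wants to avoid.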

Cospel asked Sep 26 '19 at 14:09



People also ask

What is graph freezing in TensorFlow?

Freezing is the process of identifying everything required (graph, weights, etc.) and saving it in a single file that you can easily use. A typical TensorFlow model contains 4 files, for example model-ckpt.meta, which contains the complete graph.

How do I freeze a TensorFlow model?

Let's explore the different steps we have to perform:

  • Retrieve our saved graph: load the previously saved meta-graph into the default graph and retrieve its graph_def (the ProtoBuf definition of our graph).

  • Restore the weights: start a Session and restore the weights of our graph inside that Session.
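The two steps above, followed by the conversion to constants used in the question, can be sketched with the `tf.compat.v1` API. This is a minimal sketch assuming TF 2.x; the toy checkpoint is built on the fly, and the node names `input`, `w`, `output` and the temporary directory are hypothetical stand-ins for a real model:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

# Build and save a tiny TF1-style graph (hypothetical stand-in for a trained model).
with tf.Graph().as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="input")
    w = tf1.Variable([[1.0], [2.0]], name="w")
    tf.identity(tf.matmul(x, w), name="output")
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        tf1.train.Saver().save(sess, ckpt_prefix)  # writes model.ckpt.meta, .index, .data-*

with tf.Graph().as_default() as graph:
    # Step 1: retrieve the saved graph from the .meta file.
    saver = tf1.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf1.Session(graph=graph) as sess:
        # Step 2: restore the weights inside the Session.
        saver.restore(sess, ckpt_prefix)
        # Bake the restored weights into Const nodes, keeping only what "output" needs.
        frozen_graph_def = tf1.graph_util.convert_variables_to_constants(
            sess, graph.as_graph_def(), ["output"]
        )
```

`frozen_graph_def` can then be serialized with `SerializeToString()` exactly as in the question.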

What does tf.function do?

You can use tf.function to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This helps you create performant and portable models, and it is required to use SavedModel.
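A minimal sketch of that transformation, assuming TF 2.x (`scaled_sum` is a hypothetical example function): tracing with `get_concrete_function` yields a graph whose operations can be inspected, which is the same object the answer below freezes.

```python
import tensorflow as tf


@tf.function
def scaled_sum(x, scale):
    # Python code that tf.function traces into a dataflow graph.
    return tf.reduce_sum(x) * scale


# Trace once for a fixed input signature to get a ConcreteFunction.
concrete = scaled_sum.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32),
    tf.TensorSpec(shape=[], dtype=tf.float32),
)

# The traced graph is made of ordinary TensorFlow operations.
op_types = [op.type for op in concrete.graph.get_operations()]
print(op_types)

result = scaled_sum(tf.constant([1.0, 2.0, 3.0]), tf.constant(2.0))
```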

What are graphs in TensorFlow?

Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context.
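A small sketch of those pieces, assuming TF 2.x (the names `a`, `b`, `c` are hypothetical): ops are created inside an explicit `tf.Graph` context, each op produces tensors named `<op_name>:<index>`, and the graph can be run with a v1 Session.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Each call adds a tf.Operation to `graph` and returns its output tf.Tensor.
    a = tf.constant(2.0, name="a")
    b = tf.constant(3.0, name="b")
    c = tf.add(a, b, name="c")

for op in graph.get_operations():
    print(op.name, op.type, [t.name for t in op.outputs])

# Execute the graph in TF1 style.
with tf.compat.v1.Session(graph=graph) as sess:
    value = sess.run(c)
```

The `"c:0"` naming is what `get_tensor_by_name(value + ":0")` in the question relies on.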


1 Answer

The code above is a little old. Converting VGG16 with it succeeded, but converting a resnet_v2_50 model failed. My TensorFlow version is 2.2.0. In the end, I found a useful code snippet:

import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2


# use ResNet50V2 as an example
model = tf.keras.applications.ResNet50V2()
 
# Wrap the Keras model in a tf.function and trace it for the model's input signature
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction (variables replaced by constants)
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_graph_def = frozen_func.graph.as_graph_def()  # the GraphDef that gets written below
 
layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)
 
print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)
 
# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

Ref: https://github.com/leimao/Frozen_Graph_TensorFlow/tree/master/TensorFlow_v2 (updated)
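To load the resulting single .pb file back in TF2 (the counterpart of the TF1 loading code in the question), one option is to import the GraphDef inside `tf.compat.v1.wrap_function` and prune it to the desired inputs and outputs. This is a sketch assuming TF 2.x: a tiny Dense model stands in for ResNet50V2 so it runs without downloading weights, and `wrap_frozen_graph` is a helper defined here, not a TensorFlow API.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

# Hypothetical stand-in model with input shape (None, 3).
inputs = tf.keras.Input(shape=(3,))
hidden = tf.keras.layers.Dense(4, activation="relu")(inputs)
model = tf.keras.Model(inputs, tf.keras.layers.Dense(2)(hidden))

# Freeze and write a single .pb file, as in the answer above.
full_model = tf.function(lambda x: model(x))
concrete = full_model.get_concrete_function(tf.TensorSpec([None, 3], tf.float32))
frozen_func = convert_variables_to_constants_v2(concrete)
out_dir = tempfile.mkdtemp()
tf.io.write_graph(frozen_func.graph, out_dir, "frozen_graph.pb", as_text=False)


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Import a frozen GraphDef and expose it as a callable function."""

    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    # Keep only the subgraph between the requested input and output tensors.
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs),
    )


# Parse the serialized GraphDef from disk.
graph_def = tf.compat.v1.GraphDef()
with open(os.path.join(out_dir, "frozen_graph.pb"), "rb") as f:
    graph_def.ParseFromString(f.read())

# Tensor names are read from the frozen function rather than hard-coded.
loaded = wrap_frozen_graph(
    graph_def,
    inputs=[t.name for t in frozen_func.inputs],
    outputs=[t.name for t in frozen_func.outputs],
)

sample = np.random.rand(5, 3).astype(np.float32)
expected = model(sample).numpy()
actual = loaded(tf.constant(sample))[0].numpy()
```

The frozen file should reproduce the original model's outputs, which is a cheap sanity check before deploying it.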

zhenglin Li answered Oct 27 '22 at 12:10
