 

TensorFlow: Is there a way to convert a frozen graph into a checkpoint model?

Converting a checkpoint model into a frozen graph is possible (.ckpt file to .pb file). However, is there a reverse method of converting a .pb file back into a checkpoint file?

I'd imagine it requires converting the constants back into variables - is there a way to identify the right constants, turn them into variables, and restore their values into a checkpoint model?

Currently there is support for converting variables to constants here: https://www.tensorflow.org/api_docs/python/tf/graph_util/convert_variables_to_constants but not the other way round.

A similar question has been raised here: Tensorflow: Convert constant tensor from pre-trained Vgg model to variable

But the solution relies on using a .ckpt model to restore weight variables. Is there a way to restore weight variables from a .pb file instead of a checkpoint file? This could be useful for weight pruning.

asked Jul 24 '17 by kwotsin



2 Answers

There is a method for converting constants back to trainable variables in TensorFlow, via the Graph Editor. However, you will need to specify the nodes to convert, as I'm not sure if there is a way to automatically detect this in a robust manner.

Here are the steps:

Step 1: Load frozen graph

We load our .pb file into a graph object.

import tensorflow as tf

# Load protobuf as graph, given filepath
def load_pb(path_to_pb):
    with tf.gfile.GFile(path_to_pb, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name='')
        return graph

tf_graph = load_pb('frozen_graph.pb')
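
As a quick sanity check that the import worked (a trivial sketch), you can count the operations in the loaded graph:

# The loaded graph should contain all the ops from the frozen model.
print('Loaded {} ops'.format(len(tf_graph.get_operations())))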

Step 2: Find constants that need conversion

Here are 2 ways to list the names of nodes in your graph:

  • Use this script to print them
  • print([n.name for n in tf_graph.as_graph_def().node])

The nodes you'll want to convert are likely named something along the lines of "Const". To be sure, it is a good idea to load your graph in Netron to see which tensors are storing the trainable weights. Oftentimes, it is safe to assume that all const nodes were once variables.
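
For example, one rough heuristic (a sketch only - it assumes the weights are the multi-dimensional constants, so verify against Netron) is to list every Const node together with its shape:

# List all Const nodes with their shapes. Multi-dimensional constants are
# usually the frozen weights; scalars tend to be hyperparameters or shapes.
for node in tf_graph.as_graph_def().node:
    if node.op == 'Const':
        shape = [dim.size for dim in node.attr['value'].tensor.tensor_shape.dim]
        print(node.name, shape)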

Once you have these nodes identified, let's store their names into a list:

to_convert = [...] # names of tensors to convert

Step 3: Convert constants to variables

Run this code to convert your specified constants. It essentially creates a corresponding variable for each constant, then uses the Graph Editor to unhook the constants from the graph and hook the variables on.

import numpy as np
import tensorflow as tf
import tensorflow.contrib.graph_editor as ge

const_var_name_pairs = []
with tf_graph.as_default() as g:

    # Evaluate all the constants in one session, then create a matching
    # variable for each one.
    with tf.Session() as sess:
        for name in to_convert:
            tensor = g.get_tensor_by_name('{}:0'.format(name))
            tensor_as_numpy_array = sess.run(tensor)
            var_shape = tensor.get_shape()
            # Give each variable a name that doesn't already exist in the graph.
            var_name = '{}_turned_var'.format(name)
            # Create a TensorFlow variable initialized with the values of the
            # original constant, matching its dtype.
            var = tf.get_variable(name=var_name, dtype=tensor.dtype.base_dtype,
                                  shape=var_shape,
                                  initializer=tf.constant_initializer(tensor_as_numpy_array))
            # Keep track of the variable names for later.
            const_var_name_pairs.append((name, var_name))

    # At this point, we added a bunch of tf.Variables to the graph, but they're
    # not connected to anything.

    # The magic: we use TF Graph Editor to swap the Constant nodes' outputs with
    # the outputs of our newly created Variables.

    for const_name, var_name in const_var_name_pairs:
        const_op = g.get_operation_by_name(const_name)
        var_reader_op = g.get_operation_by_name(var_name + '/read')
        ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

Step 4: Save result as .ckpt

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        save_path = tf.train.Saver().save(sess, 'model.ckpt')
        print("Model saved in path: %s" % save_path)

And voilà! You should be done at this point :) I was able to get this working myself, and verified that the model weights are preserved; the only difference is that the graph is now trainable. Please let me know if there are any issues.
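
If you want to double-check that nothing was lost, a quick sanity check (a minimal sketch reusing tf_graph and const_var_name_pairs from the steps above) is to compare each original constant against the variable that replaced it:

import numpy as np

# The detached Const ops are still present in the graph, so we can compare
# their values with the freshly initialized variables.
with tf_graph.as_default() as g:
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for const_name, var_name in const_var_name_pairs:
            const_val = sess.run(g.get_tensor_by_name(const_name + ':0'))
            var_val = sess.run(g.get_tensor_by_name(var_name + '/read:0'))
            assert np.array_equal(const_val, var_val), const_name
print('All converted variables match their original constants.')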

answered Oct 03 '22 by Max Wu

If you have the source code that built the network, it can be done relatively easily, because the names of the convolution/fully-connected layers are not changed by the freeze-graph method. So you can basically inspect the graph, match the constant operations to their corresponding variables, and just load the variables with the constants' values. -- by Almog David

Thanks to @Almog David's excellent answer above. I was facing the exact same situation:

  • I have frozen_inference_graph.pb but not checkpoints;
  • I have the source code to produce frozen_inference_graph.pb but I don't know the parameters.

Below are the three steps to solve the dilemma.

1. Get pairs of node names and values from frozen_inference_graph.pb

import tensorflow as tf
from tensorflow.python.framework import tensor_util

def get_node_values(old_graph_path):
    # Parse the frozen graph; the node values can be read straight from the
    # GraphDef, without building a graph or starting a session.
    old_graph_def = tf.GraphDef()
    with tf.gfile.GFile(old_graph_path, "rb") as fid:
        old_graph_def.ParseFromString(fid.read())

    value_dict = {}
    for node in old_graph_def.node:
        value = node.attr['value'].tensor
        try:
            # Get the name and value (as a numpy array) from the tensor proto.
            value_dict[node.name] = tensor_util.MakeNdarray(value)
        except Exception:
            # Some nodes don't carry a value (e.g. ops like Squeeze);
            # just ignore them.
            pass
    return value_dict

value_dict = get_node_values("frozen_inference_graph.pb")
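
To get a feel for what was extracted, you can print a few entries (a trivial check; the weight-like entries should show multi-dimensional shapes):

# Inspect some of the recovered nodes: name -> shape of the numpy array.
for name, array in sorted(value_dict.items())[:10]:
    print(name, array.shape)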

2. Create a new graph using the existing code; tune the model parameters until all the nodes in the new graph are present in value_dict

new_graph = tf.Graph()
with new_graph.as_default():
    tf.train.create_global_step()
    # existing model-building code
    # ...
    # ...

    model_variables = tf.model_variables()
    # Any variable listed here has no matching constant in the frozen graph,
    # so the model parameters still need tuning.
    unseen_variables = set(model_variable.name[:-2]
                           for model_variable in model_variables) - set(value_dict.keys())
    print("\n".join(sorted(unseen_variables)))

3. Assign values to variables and save to a checkpoint (or save to a graph)

new_graph_path = "model.ckpt"

with new_graph.as_default():
    saver = tf.train.Saver(model_variables)

    assign_ops = []
    for variable in model_variables:
        print("Assigning", variable.name[:-2])
        # Variable names end with ":0" but constant names don't.
        value = value_dict[variable.name[:-2]]
        assign_ops.append(variable.assign(value))

    sess = tf.Session(graph=new_graph)
    sess.run(tf.global_variables_initializer())
    sess.run(assign_ops)
    saver.save(sess, new_graph_path)

This is the only way I could think of to solve this problem. However, it still has some drawbacks: if you reload the model checkpoint, you will find (alongside all the useful variables) a lot of unwanted assign variables such as Assign_700/value. This is hard to avoid and looks ugly. If you have better suggestions, feel free to comment. Thanks.
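
For what it's worth, one possible way to reduce the extra Assign nodes (an untested sketch using the TF 1.x Variable.load API, reusing the names from step 3) is to write the values directly into the variables' memory instead of building assign ops:

# Sketch: Variable.load writes a value into the variable's memory without
# adding new ops to the graph, so the saved graph is not polluted with
# Assign_*/value constants.
with new_graph.as_default():
    sess = tf.Session(graph=new_graph)
    sess.run(tf.global_variables_initializer())
    for variable in model_variables:
        variable.load(value_dict[variable.name[:-2]], sess)
    saver.save(sess, new_graph_path)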

answered Oct 03 '22 by Yoyo