In TensorFlow it is fairly easy to load a trained model back in through checkpoints. However, that use case seems oriented towards users who want to either run evaluation or continue training on a checkpointed model.
What is the simplest way in TensorFlow to load a pre-trained model and use it (without training it further) to produce results that are then fed into a new model?
Right now the most promising approach seems to be loading the trained model with tf.train.import_meta_graph(), fetching its input and output tensors with the graph's get_tensor_by_name() method, and using tf.stop_gradient() to keep gradients from flowing back into it.
What is the best-practice setup for this sort of thing?
The most straightforward solution would be to freeze the pre-trained model variables using this function:
import tensorflow as tf  # TensorFlow 1.x API


def freeze_graph(model_dir, output_node_names):
    """Extract the subgraph defined by the output nodes and convert
    all of its variables into constants.

    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string containing the names of all output nodes,
            comma separated
    Returns:
        A GraphDef in which every variable has been replaced by a constant.
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError("Export directory doesn't exist")
    if not output_node_names:
        print("You need to supply the name of the output node")
        return -1
    # Retrieve the full path of the latest checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    # Clear devices so that TensorFlow can decide on which device to load the operations
    clear_devices = True
    # Start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the current default Graph.
        # Assumption: the .meta file sits next to the checkpoint, as tf.train.Saver writes it by default.
        saver = tf.train.import_meta_graph(input_checkpoint + ".meta", clear_devices=clear_devices)
        # Restore the weights
        saver.restore(sess, input_checkpoint)
        # Use a built-in TF helper to convert the variables to constants
        frozen_graph = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            output_node_names.split(",")  # The output node names are used to select the useful nodes
        )
    return frozen_graph
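If you are not sure which names to pass as output_node_names, one option is to list the operations of the graph and pick the ones you need. The snippet below is only a sketch: the tiny graph and the names "input", "weights", "logits" and "prediction" are hypothetical stand-ins for your own model, and it assumes TensorFlow 1.14+ (or 2.x) where the tensorflow.compat.v1 module is available.

```python
import tensorflow.compat.v1 as tf  # assumption: the v1 compatibility module is available

tf.disable_v2_behavior()

# Hypothetical stand-in for a trained model; replace with your own graph.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 3], name="input")
    w = tf.Variable(tf.ones([3, 1]), name="weights")
    logits = tf.matmul(x, w, name="logits")
    # Giving the output an explicit name makes it easy to find later.
    tf.identity(logits, name="prediction")

# Every operation (node) in the graph has a name; these are the names
# that freeze_graph expects, without the ":0" tensor suffix.
op_names = [op.name for op in graph.get_operations()]
print(op_names)

# Comma-separated string in the format freeze_graph takes.
output_node_names = ",".join(["prediction"])
```

Note that node names have no suffix, while tensor names (as used by get_tensor_by_name) end in ":0", ":1" and so on.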
Then you can build your new model on top of the pre-trained one. Note that freeze_graph returns a GraphDef rather than a Graph, so it has to be imported into the graph in which the new model is built:
# Get the frozen graph (a GraphDef)
frozen_graph_def = freeze_graph(YOUR_MODEL_DIR, YOUR_OUTPUT_NODES)
# Import the frozen graph into the current default graph and get the
# output tensor of the pre-trained model (a tensor name such as "output:0")
(pre_trained_model_result,) = tf.import_graph_def(
    frozen_graph_def,
    return_elements=[OUTPUT_TENSOR_NAME_OF_PRETRAINED_MODEL],
    name="")
# Let's say you want the square root of the pre-trained model's result
my_new_operation_results = tf.sqrt(pre_trained_model_result)
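For a complete picture, here is a small self-contained sketch of the same idea that also feeds a tensor from the new model into the frozen graph, using the input_map argument of tf.import_graph_def. The toy model and the names "pretrained_input", "pretrained_output" and "new_input" are made up for illustration, the weights are initialized in place instead of restored from a checkpoint, and it again assumes tensorflow.compat.v1 is available.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or 2.x with the v1 compatibility API

tf.disable_v2_behavior()

# Toy stand-in for the pre-trained model: y = x . w (hypothetical names).
pretrained_graph = tf.Graph()
with pretrained_graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="pretrained_input")
    w = tf.Variable([[4.0], [9.0]], name="w")
    tf.matmul(x, w, name="pretrained_output")
    with tf.Session() as sess:
        # In practice the weights would come from saver.restore(...).
        sess.run(tf.global_variables_initializer())
        # Variables become constants, so they can no longer be trained.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, pretrained_graph.as_graph_def(), ["pretrained_output"])

# New model: plug its own tensor into the frozen graph's input.
new_graph = tf.Graph()
with new_graph.as_default():
    new_input = tf.placeholder(tf.float32, shape=[None, 2], name="new_input")
    (features,) = tf.import_graph_def(
        frozen_graph_def,
        input_map={"pretrained_input:0": new_input},  # replace the old placeholder
        return_elements=["pretrained_output:0"],
        name="pretrained")
    result = tf.sqrt(features)
    with tf.Session() as sess:
        output = sess.run(result, feed_dict={new_input: [[1.0, 0.0], [0.0, 1.0]]})
    # The imported graph contributes no trainable variables to the new model.
    n_trainable = len(tf.trainable_variables())

print(output)
print(n_trainable)
```

Because the frozen weights are constants, an optimizer in the new model has nothing to update in the pre-trained part, so tf.stop_gradient() is not needed to protect them.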