
Is it possible to convert a trained model in TensorFlow to an object that could be used for transfer learning?

I would like to use transfer learning as described here: https://www.tensorflow.org/tutorials/images/transfer_learning

The issue is that the model I want to use as the base model is not one of the built-in Keras models such as MobileNetV2. I therefore think I need an extra first step (step 1) before I can follow the tutorial's transfer-learning steps (steps 2-6):
1. Load the model from the directory that includes the Saved_Model files.
2. Freeze the model (make its trainable parameters unchangeable).
3. Make a separate layer and stack it on top of the frozen model
4. Train the resulting model.
5. Save the newly trained model.
6. Do predictions using the newly trained model.
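Steps 2-6 can be sketched with the Keras API roughly as follows. This is a minimal illustration, not the tutorial's code: the small `base_model` is a hypothetical stand-in for whatever step 1 eventually loads, the input size and dummy data are made up, and saving to the `.keras` format assumes TensorFlow 2.12 or later.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the custom base model; in practice step 1 would
# load the real model instead of building this small one.
base_model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(16, activation="relu"),
], name="base_model")

# Step 2: freeze the base model so fit() does not update its weights.
base_model.trainable = False
frozen_before = [w.copy() for w in base_model.get_weights()]

# Step 3: stack a new, trainable layer on top of the frozen base.
model = tf.keras.Sequential([base_model, tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Step 4: train on dummy data (replace with your own dataset).
rng = np.random.default_rng(0)
x = rng.random((32, 8), dtype=np.float32)
y = rng.random((32, 1), dtype=np.float32)
model.fit(x, y, epochs=2, verbose=0)

# Step 5: save the newly trained model (.keras format, TF 2.12+ assumed).
save_path = os.path.join(tempfile.mkdtemp(), "transfer_model.keras")
model.save(save_path)

# Step 6: reload it and predict.
reloaded = tf.keras.models.load_model(save_path)
predictions = reloaded.predict(x[:4], verbose=0)
```

Setting `trainable = False` on the nested model is what keeps its weights fixed; only the new `Dense(1)` layer's kernel and bias are updated during `fit`.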

My question is about the first step. When I try to load the model with the following Python script, I get an error that I do not know how to fix:

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np
tf.saved_model.load(
    export_dir='/dir_to_the_model_files/', tags=None
)

The error is:

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
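The message says the parser met a top-level field named `node` at line 1, column 1. In TensorFlow's protobuf definitions, `node` is a field of `GraphDef`, whereas the top-level fields of a `SavedModel` are `saved_model_schema_version` and `meta_graphs`, so the file most likely holds a plain graph rather than a SavedModel. The helper below is a rough, pure-Python sketch (no TensorFlow needed) for checking which kind of text proto a `.pbtxt` file contains; `first_field_name` and `guess_pbtxt_kind` are hypothetical names, and the check only looks at the first field, so it is a heuristic rather than a real proto parser.

```python
import re

# Top-level field names of the two protobuf messages involved
# (SavedModel from saved_model.proto, GraphDef from graph.proto).
SAVED_MODEL_FIELDS = {"saved_model_schema_version", "meta_graphs"}
GRAPH_DEF_FIELDS = {"node", "versions", "version", "library"}

_FIELD_RE = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")


def first_field_name(pbtxt_text):
    """Return the first top-level field name in text-format protobuf, or None."""
    for line in pbtxt_text.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue  # skip blank lines and comments
        match = _FIELD_RE.match(stripped)
        return match.group(0) if match else None
    return None


def guess_pbtxt_kind(pbtxt_text):
    """Rough guess at which message a .pbtxt file holds (not a real parser)."""
    name = first_field_name(pbtxt_text)
    if name in SAVED_MODEL_FIELDS:
        return "SavedModel"
    if name in GRAPH_DEF_FIELDS:
        return "GraphDef"
    return "unknown"
```

Usage: `with open("/dir_to_the_model_files/saved_model.pbtxt") as f: print(guess_pbtxt_kind(f.read()))`. Given the error above, this would presumably report `GraphDef`.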

I also wonder whether the TensorFlow files (including saved_model.ckpt-0.data-00000-of-00001) could be converted into a format the Keras API can read (e.g. an HDF5 file, as opened by h5py.File), which would make transfer learning work as in the tutorial. I could then extract the base model with something similar to the following and carry on with the remaining steps:

import tensorflow as tf

IMG_SHAPE = (160, 160, 3)  # 160x160 RGB images, as in the tutorial
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')

Or, preferably, use tf.keras.models.load_model (https://www.tensorflow.org/api_docs/python/tf/keras/models/load_model):

tf.keras.models.load_model(
    filepath, custom_objects=None, compile=True
)
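If the goal is to reuse the checkpoint's weights in Keras, one way to see what the checkpoint actually contains is to list its variables with `tf.train.list_variables` and read them with `tf.train.load_checkpoint`. This is a sketch, not a converter: the helper names are hypothetical, and copying the arrays into a Keras model (e.g. with `layer.set_weights`) only works if the same architecture can be rebuilt in Keras and each variable name matched to a layer by hand.

```python
import tensorflow as tf


def describe_checkpoint(prefix):
    """Return {variable_name: shape} for every tensor stored in a checkpoint."""
    return dict(tf.train.list_variables(prefix))


def read_checkpoint_tensor(prefix, name):
    """Return the stored value of one checkpoint variable as a NumPy array."""
    reader = tf.train.load_checkpoint(prefix)
    return reader.get_tensor(name)
```

For the files in the question, the prefix would be `/dir_to_the_model_files/saved_model.ckpt-0` (without the `.data-00000-of-00001` or `.index` suffix).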

Update: I tried the following method, but it does not work (tf was imported through the compatibility module, i.e. import tensorflow.compat.v1 as tf):

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/dir_to_the_model_files/saved_model.ckpt-0.meta')
    saver.restore(sess, "/dir_to_the_model_files/saved_model.ckpt-0")
    loaded = tf.saved_model.load(sess,tags=None,export_dir="/dir_to_the_model_files",import_scope=None)

It produces the following warnings and error:

WARNING:tensorflow:The saved meta_graph is possibly from an older release:
'metric_variables' collection should be of type 'byte_list', but instead is of type 'node_list'.
INFO:tensorflow:Restoring parameters from /dir_to_the_model_files/saved_model.ckpt-0
<tensorflow.python.training.saver.Saver object at 0x2aaab4824a50>
WARNING:tensorflow:From <ipython-input-3-b8fd24f6b841>:9: load (from tensorflow.python.saved_model.loader_impl) is deprecated and will be removed in a future version.
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.loader.load or tf.compat.v1.saved_model.load. There will be a new function for importing SavedModels in Tensorflow 2.0.

OSError: Cannot parse file b'/dir_to_the_model_files/saved_model.pbtxt': 1:1 : Message type "tensorflow.SavedModel" has no field named "node"..
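The log shows that restoring the TF1 checkpoint through its `.meta` graph succeeds; only `saved_model.load` fails on the `.pbtxt`. One possible workaround, sketched here under the assumption that the meta graph imports cleanly, is to restore the checkpoint in a TF1-style session and re-export it with `tf.compat.v1.saved_model.Builder`, so that TF2's `tf.saved_model.load` receives a genuine SavedModel. The function `checkpoint_to_saved_model` and the tensor names `"input:0"`/`"output:0"` are hypothetical; the real names have to be looked up in the graph (e.g. via `graph.get_operations()`).

```python
import tensorflow as tf

tf1 = tf.compat.v1


def checkpoint_to_saved_model(meta_path, ckpt_prefix, export_dir, input_name, output_name):
    """Restore a TF1 meta graph + checkpoint and re-export it as a SavedModel.

    input_name / output_name are tensor names such as "input:0" (hypothetical;
    look them up in your own graph). export_dir must not exist yet.
    """
    graph = tf.Graph()
    with graph.as_default():
        with tf1.Session(graph=graph) as sess:
            # Same two calls as in the question: rebuild the graph, load weights.
            saver = tf1.train.import_meta_graph(meta_path)
            saver.restore(sess, ckpt_prefix)
            # Describe which tensors callers feed and fetch.
            signature = tf1.saved_model.predict_signature_def(
                inputs={"x": graph.get_tensor_by_name(input_name)},
                outputs={"y": graph.get_tensor_by_name(output_name)},
            )
            builder = tf1.saved_model.Builder(export_dir)
            builder.add_meta_graph_and_variables(
                sess,
                [tf1.saved_model.tag_constants.SERVING],
                signature_def_map={"serving_default": signature},
            )
            builder.save()
    return export_dir
```

After the export, `tf.saved_model.load(export_dir).signatures["serving_default"]` can be called with the keyword argument `x=...` and returns a dict containing `"y"`.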
Remy asked Dec 10 '25 22:12



1 Answer

The TensorFlow documentation for tf.saved_model.load might help:

SavedModels from tf.estimator.Estimator or 1.x SavedModel APIs have a flat graph instead of tf.function objects. These SavedModels will have functions corresponding to their signatures in the .signatures attribute, but also have a .prune method which allows you to extract functions for new subgraphs. This is equivalent to importing the SavedModel and naming feeds and fetches in a Session from TensorFlow 1.x.
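Following the quoted documentation, a 1.x SavedModel loaded with TF2's `tf.saved_model.load` can be cut down to the subgraph between chosen feed and fetch tensors with `.prune`. A minimal sketch, assuming the directory already contains a valid 1.x SavedModel; `load_v1_subgraph` and the tensor names are hypothetical.

```python
import tensorflow as tf


def load_v1_subgraph(export_dir, feed_name, fetch_name):
    """Load a TF 1.x SavedModel in TF2 and extract the feed -> fetch subgraph."""
    imported = tf.saved_model.load(export_dir)
    # prune() takes tensor names, like feeds and fetches in a TF1 Session.
    pruned = imported.prune(feed_name, fetch_name)
    # Return `imported` too, so the restored variables stay referenced.
    return imported, pruned
```

The pruned function is then called directly on tensors, e.g. `pruned(tf.constant([[1.0, 1.0]]))`.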

You might have to use the deprecated v1 API call, tf.compat.v1.saved_model.load: https://www.tensorflow.org/api_docs/python/tf/compat/v1/saved_model/load
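For reference, the v1 call takes a session, a list of tags, and the export directory; the tags must match those used when the model was saved, so `tags=None` as in the question's update will not select a meta graph. A hedged sketch, assuming a valid SavedModel saved with the `serve` tag; `run_v1_saved_model` and the tensor names are hypothetical.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def run_v1_saved_model(export_dir, feed_name, fetch_name, feed_value):
    """Load a SavedModel with the TF1-style API and evaluate one fetch."""
    graph = tf.Graph()
    with graph.as_default():
        with tf1.Session(graph=graph) as sess:
            # Tags must be a list matching the saved meta graph ("serve" here).
            tf1.saved_model.load(
                sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
            # Feeds and fetches are addressed by tensor name, e.g. "input:0".
            return sess.run(fetch_name, feed_dict={feed_name: feed_value})
```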

Brian Spiering answered Dec 13 '25 11:12
