To share our trained TensorFlow networks, we freeze the graph into a .pb file. We also create an XML file with metadata such as the input and output tensors, the type of pre-processing to apply, training data information, and so on. The models are then served from Java or C# by loading the graph and evaluating the tensors.
To make sharing easier, I would like to include this XML data somewhere in the .pb file. Is there any way to do this? One idea would be to store it as a tf.constant, but I don't see how I could connect it to the normal graph.
Note this is using freeze_graph.py. Is the new SavedModel format more suitable?
First of all, yes, you should use the new SavedModel format: it is what the TF team will support going forward, and it works with Keras as well. You can add an additional endpoint to the model that returns a constant tensor (as you mention) containing your XML data as a string.
This approach is good because it's hermetic -- the underlying SavedModel serialization format does not matter, because your metadata is stored in the computation graph itself.
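As a minimal sketch of that idea (the class name, signature name, and XML content here are illustrative, not from the question), a plain tf.Module with a zero-argument tf.function can expose the metadata as an extra serving signature:

```python
import tempfile
import tensorflow as tf

class MetadataModule(tf.Module):
    def __init__(self, metadata):
        super().__init__()
        # Store the XML string as a constant tensor inside the module.
        self.metadata = tf.constant(metadata)

    @tf.function(input_signature=[])
    def get_metadata(self):
        return self.metadata

module = MetadataModule('<model><input name="x"/></model>')
export_dir = tempfile.mkdtemp()

# Export get_metadata as a named signature alongside the model.
tf.saved_model.save(module, export_dir,
                    signatures={'metadata': module.get_metadata})

restored = tf.saved_model.load(export_dir)
print(restored.get_metadata().numpy().decode('utf-8'))
```

The metadata travels inside the SavedModel directory itself, so any consumer that can load the model can also read it back.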
See the answer to this question: Saving a TF2 keras model with custom signature defs. That answer doesn't get you 100% of the way there for Keras, because it wraps the model inside a tf.Module, which doesn't interoperate nicely with tf.keras.models.load_model. Luckily, subclassing tf.keras.Model works as well in TF2, if you add a tf.function decorator:
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, metadata, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
        self.metadata = tf.constant(metadata)

    def call(self, inputs):
        x = self.dense1(inputs)
        return self.dense2(x)

    @tf.function(input_signature=[])
    def get_metadata(self):
        return self.metadata

model = MyModel('metadata_test')

# Calling the model once builds it, so Keras knows its input shape.
# (You could also define the input shape manually.)
input_arr = tf.random.uniform((5, 5, 1))
outputs = model(input_arr)
Then you can save and load your model as follows:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
And finally use model_loaded.get_metadata() to retrieve your constant metadata tensor.