How do I save and load a DNN classifier in TensorFlow? I'm asking about the default Iris classifier program from the getting-started guide (https://www.tensorflow.org/get_started/estimator).
To save and reuse the classifier, you can simply reload it from the same model_dir path.
For example, in the method where you want to use the classifier, create the classifier again with the same model_dir. The estimator then restores itself from the latest checkpoint in that directory, i.e. from whatever state it was previously in, as long as feature_columns, hidden_units and n_classes are the same as when it was trained.
I use this to train the model and then reload it to test single examples.
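import tensorflow as tf
# Assumed: the same feature columns used at training time (four numeric Iris features under key "x")
feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]
# Same architecture + same model_dir: the estimator reloads its latest checkpoint
classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 20, 10],
    n_classes=3,
    model_dir="/tmp/iris_model")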
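Here is a minimal sketch of that train-then-reload pattern. It assumes a TensorFlow release that still includes tf.estimator (the API is deprecated and missing from recent 2.x releases), uses random stand-in data instead of the real Iris CSV files, and uses a temporary directory in place of /tmp/iris_model. The helpers make_classifier, train_input_fn and predict_input_fn are hypothetical names chosen for this example.

```python
import tempfile

import numpy as np
import tensorflow as tf

model_dir = tempfile.mkdtemp()  # stands in for "/tmp/iris_model"
feature_columns = [tf.feature_column.numeric_column("x", shape=[4])]


def make_classifier():
    # Hypothetical helper: same architecture and same model_dir every time
    return tf.estimator.DNNClassifier(
        feature_columns=feature_columns,
        hidden_units=[10, 20, 10],
        n_classes=3,
        model_dir=model_dir)


# Random stand-in for the Iris data: 30 examples, 4 features, labels 0..2
rng = np.random.default_rng(0)
features = rng.normal(size=(30, 4)).astype(np.float32)
labels = rng.integers(0, 3, size=30).astype(np.int32)


def train_input_fn():
    ds = tf.data.Dataset.from_tensor_slices(({"x": features}, labels))
    return ds.shuffle(30).repeat().batch(10)


# Training run: writes checkpoints into model_dir
make_classifier().train(input_fn=train_input_fn, steps=20)

# Later (e.g. in a separate test script): rebuild with the same model_dir
restored = make_classifier()
print(restored.get_variable_value("global_step"))


def predict_input_fn():
    # A single example, batched
    return tf.data.Dataset.from_tensor_slices({"x": features[:1]}).batch(1)


prediction = next(restored.predict(input_fn=predict_input_fn))
print(prediction["class_ids"])
```

The second DNNClassifier never trains; it picks up the weights and the global step from the checkpoint in model_dir. If you change hidden_units or the feature columns between the two, restoring fails because the checkpoint's variables no longer match the new graph.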
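The first thing you need to do is create a TensorFlow Saver object inside your session, after your model's variables have been defined (a Saver created when there are no variables raises an error):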
with tf.Session(graph=graph) as sess:
    saver = tf.train.Saver()
Then, after training and while still inside the session, call the save method:
saver.save(sess, 'path/to/model_file')
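You don't need to specify a file extension, since save() adds the extensions itself: it writes model_file.meta (the graph), model_file.index and model_file.data-* (the variable values), plus a small file named checkpoint that records the most recent save.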
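To see which files save() actually produces, here is a minimal sketch. It assumes a TF 1.x-style graph run through tf.compat.v1 (which also works in TF 2.x); the temporary directory, the variable w and the tf1 alias are illustrative choices, not part of the original answer.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API, also shipped with TF 2.x

ckpt_dir = tempfile.mkdtemp()
save_prefix = os.path.join(ckpt_dir, "model_file")  # note: no extension

graph = tf.Graph()
with graph.as_default():
    # A Saver needs at least one variable to save; w is a toy example
    w = tf1.get_variable("w", shape=[4, 3], initializer=tf1.zeros_initializer())
    with tf1.Session(graph=graph) as sess:
        saver = tf1.train.Saver()
        sess.run(tf1.global_variables_initializer())
        saved_path = saver.save(sess, save_prefix)

written = sorted(os.listdir(ckpt_dir))
print(saved_path)
print(written)
```

tf.train.latest_checkpoint() reads that small checkpoint file to find the most recent prefix, which is why the restore step can be given just the directory.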
To restore the model, open a new session (there is no need to pass a graph, since import_meta_graph rebuilds the saved graph into the default graph) and do the following:
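with tf.Session() as sess:
    # Rebuild the saved graph structure from the .meta file
    new_saver = tf.train.import_meta_graph('path/to/model_file.meta')
    # Load the variable values from the latest checkpoint in the directory the model was saved to
    new_saver.restore(sess, tf.train.latest_checkpoint('path/to/'))
    # Restore the tensors you want (usually the ones you use in feed_dict and sess.run).
    # "x" and "output" must be the names you gave those tensors when building the graph.
    graph = tf.get_default_graph()
    x = graph.get_tensor_by_name("x:0")
    output = graph.get_tensor_by_name("output:0")
    # input_data stands for your own batch of inputs (e.g. a NumPy array), not the tensor x
    feed_dict = {x: input_data}
    [result] = sess.run([output], feed_dict=feed_dict)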
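For an end-to-end check, the following sketch saves a toy graph and then restores it into a completely fresh graph by tensor name, following the save and restore steps above. It assumes TF 1.x-style graphs via tf.compat.v1; the tensor names "x" and "output", the toy weight matrix w and input_data are hypothetical stand-ins for your own model. Naming the tensors explicitly when you build the graph (name="x", name="output") is what makes get_tensor_by_name work later.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1

ckpt_dir = tempfile.mkdtemp()
save_prefix = os.path.join(ckpt_dir, "model_file")
input_data = np.ones((2, 4), dtype=np.float32)  # hypothetical batch of two examples

# Training side: build a toy graph with explicitly named input and output tensors
train_graph = tf.Graph()
with train_graph.as_default():
    x = tf1.placeholder(tf.float32, shape=[None, 4], name="x")
    w = tf1.get_variable("w", shape=[4, 3], initializer=tf1.constant_initializer(0.5))
    output = tf.identity(tf.matmul(x, w), name="output")
    saver = tf1.train.Saver()
    with tf1.Session(graph=train_graph) as sess:
        sess.run(tf1.global_variables_initializer())
        expected = sess.run(output, feed_dict={x: input_data})
        saver.save(sess, save_prefix)

# Inference side: a fresh graph; nothing is rebuilt by hand
restore_graph = tf.Graph()
with restore_graph.as_default():
    with tf1.Session(graph=restore_graph) as sess:
        new_saver = tf1.train.import_meta_graph(save_prefix + ".meta")
        new_saver.restore(sess, tf1.train.latest_checkpoint(ckpt_dir))
        graph = tf1.get_default_graph()  # this is restore_graph
        x = graph.get_tensor_by_name("x:0")
        output = graph.get_tensor_by_name("output:0")
        [result] = sess.run([output], feed_dict={x: input_data})

print(result)
```

If get_tensor_by_name raises a KeyError, printing [op.name for op in graph.get_operations()] shows the names TensorFlow actually assigned to the operations.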
You can also check this tutorial about saving and restoring TensorFlow models. I hope it helps!