 

Saving, loading, and predicting from a TensorFlow Estimator model (2.0)


Is there a guide anywhere for serializing and restoring Estimator models in TF2? The documentation is very spotty, and much of it has not been updated for TF2. I've yet to see a clear and complete example anywhere of an Estimator being saved, loaded from disk, and used to predict from new inputs.

TBH, I'm a bit baffled by how complicated this appears to be. Estimators are billed as simple, relatively high-level ways of fitting standard models, yet the process for using them in production seems very arcane. For example, when I load a model from disk via tf.saved_model.load(export_path) I get an AutoTrackable object:

<tensorflow.python.training.tracking.tracking.AutoTrackable at 0x7fc42e779f60>

It's not clear why I don't get my Estimator back. It looks like there used to be a useful-sounding function, tf.contrib.predictor.from_saved_model, but since contrib is gone, it no longer appears to be available (except, it seems, in TFLite).

Any pointers would be very helpful. As you can see, I'm a bit lost.

Chris Fonnesbeck asked Nov 20 '19 16:11



1 Answer

Maybe the author doesn't need the answer anymore, but I was able to save and load a DNNClassifier using TensorFlow 2.1:

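# training.py
from pathlib import Path
import tensorflow as tf

MODEL_DIR = '<model_dir>'  # placeholder: directory for checkpoints and exports

...  # elided in the original: data loading and the feature_columns definition

# Creating the estimator
estimator = tf.estimator.DNNClassifier(
    model_dir=MODEL_DIR,
    hidden_units=[1000, 500],
    feature_columns=feature_columns,  # this is a list defined earlier
    n_classes=2,
    optimizer='adam')

# Train before exporting (elided in the original): export_saved_model needs a
# checkpoint in MODEL_DIR, e.g. estimator.train(input_fn=train_input_fn),
# where train_input_fn is hypothetical
...

# Serving input: the exported model accepts serialized tf.train.Example protos
feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)
export_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)
# export_saved_model writes into a new timestamped subdirectory of MODEL_DIR
# and returns that subdirectory's path as bytes
servable_model_path = Path(estimator.export_saved_model(MODEL_DIR, export_input_fn).decode('utf8'))
print(f'Model saved at {servable_model_path}')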
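export_saved_model does not write the SavedModel directly into the base directory: each export lands in a new subdirectory named with an integer Unix timestamp, and that subdirectory is what the loader needs. If testing.py runs separately from training.py, a small helper can find the newest export. This is a minimal sketch; `latest_export_dir` is a hypothetical name, and it assumes the usual Estimator export layout (numeric directory names, each containing saved_model.pb).

```python
from pathlib import Path
from typing import Optional


def latest_export_dir(export_dir_base) -> Optional[Path]:
    """Return the newest SavedModel export under export_dir_base, or None.

    Assumed layout: one subdirectory per export, named with an integer Unix
    timestamp and containing saved_model.pb. Everything else (checkpoint
    files, eval/, temporary staging directories) is skipped.
    """
    candidates = [
        p for p in Path(export_dir_base).iterdir()
        if p.is_dir() and p.name.isdigit() and (p / 'saved_model.pb').exists()
    ]
    if not candidates:
        return None
    # Compare timestamps numerically rather than as strings
    return max(candidates, key=lambda p: int(p.name))
```

In testing.py this could be used as `tf.saved_model.load(str(latest_export_dir('<model_dir>')))`. Checking for saved_model.pb avoids picking up a directory whose export is still being written or failed halfway.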

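For loading, you found the correct method, tf.saved_model.load. You just need to retrieve the 'predict' signature from the AutoTrackable object it returns, and load from the timestamped export directory that training.py prints rather than from <model_dir> itself: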

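# testing.py
import pandas as pd
import tensorflow as tf

def predict_input_fn(test_df):
    '''Convert your dataframe using tf.train.Example() and tf.train.Features()
    into a 1-D string tensor of serialized examples'''
    examples = []
    ...  # elided in the original: append one serialized tf.train.Example per row
    # dtype=tf.string keeps the type right even if examples is empty
    return tf.constant(examples, dtype=tf.string)

test_df = pd.read_csv('test.csv')  # add any read_csv options you need

# Loading the exported model: use the timestamped directory printed by
# training.py (placeholder below), not the estimator's model_dir
predict_fn = tf.saved_model.load('<servable_model_path>').signatures['predict']
# Predict: returns a dict of output tensors, e.g. 'probabilities', 'class_ids'
predictions = predict_fn(examples=predict_input_fn(test_df))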
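The body of predict_input_fn is elided above. One way to fill it in is sketched below, under two assumptions: every feature holds a single scalar per row, and `feature_spec` is the same dict that training.py passed to build_parsing_serving_input_receiver_fn (rebuild it in testing.py with tf.feature_column.make_parse_example_spec from the same feature columns). `to_serialized_examples` is a hypothetical helper; it picks the bytes, int64, or float list type from each feature's dtype in the spec, so the exported parsing op accepts the examples.

```python
import tensorflow as tf


def to_serialized_examples(records, feature_spec):
    """Serialize dict rows into a 1-D string tensor of tf.train.Example protos.

    Assumes one scalar value per feature per row; the list type is chosen
    from the dtype recorded in the parsing spec.
    """
    serialized = []
    for row in records:
        features = {}
        for name, spec in feature_spec.items():
            value = row[name]
            if spec.dtype == tf.string:
                if isinstance(value, str):
                    value = value.encode('utf-8')
                feature = tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
            elif spec.dtype == tf.int64:
                feature = tf.train.Feature(int64_list=tf.train.Int64List(value=[int(value)]))
            else:  # tf.float32, the default dtype of tf.feature_column.numeric_column
                feature = tf.train.Feature(float_list=tf.train.FloatList(value=[float(value)]))
            features[name] = feature
        example = tf.train.Example(features=tf.train.Features(feature=features))
        serialized.append(example.SerializeToString())
    # dtype=tf.string keeps the tensor type right even when records is empty
    return tf.constant(serialized, dtype=tf.string)
```

With this, predict_input_fn could return `to_serialized_examples(test_df.to_dict('records'), feature_spec)`, which matches the `examples` input of the 'predict' signature.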

Hope that this can help other people too (:

Omar Cotugno answered Sep 18 '22 09:09
