 

Is it possible to restore a tensorflow estimator from saved model?

I use tf.estimator.train_and_evaluate() to train my custom estimator. My dataset is partitioned 8:1:1 for training, evaluation and test. At the end of the training, I would like to restore the best model, and evaluate the model using tf.estimator.Estimator.evaluate() with the test data. The best model is currently exported using tf.estimator.BestExporter.

While tf.estimator.Estimator.evaluate() accepts checkpoint_path and restores variables, I cannot find any easy way to use the exported model generated by tf.estimator.BestExporter. I could of course keep all checkpoints during training, and look for the best model by myself, but that seems quite suboptimal.

Could anyone tell me an easy workaround? Maybe it is possible to convert a saved model to a checkpoint?

Taro Kiritani asked Nov 07 '18 05:11


4 Answers

Maybe you can try tf.estimator.WarmStartSettings: https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/estimator/WarmStartSettings

It can load the weights stored with the exported model (the variables checkpoint next to saved_model.pb) and continue training from them. This worked in my project.

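You can set up the warm start as follows:

ws = tf.estimator.WarmStartSettings(ckpt_to_initialize_from="/[model_dir]/export/best-exporter/[timestamp]/variables/variables")
estimator = tf.estimator.Estimator(model_fn=model_fn, warm_start_from=ws)  # model_fn: your own model function

Training then starts from the weights of the best exported model.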
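The [timestamp] part of that path has to be filled in by hand. Below is a minimal sketch for locating it, assuming the exporter writes to <model_dir>/export/<exporter_name>/<timestamp>/ and that finished exports have purely numeric timestamp directory names; latest_export_variables_prefix is a hypothetical helper, not a TensorFlow API. Because BestExporter only writes a new export when the evaluation result improves, the newest timestamp should correspond to the best model so far.

```python
import os


def latest_export_variables_prefix(model_dir, exporter_name="best_exporter"):
    """Return the checkpoint prefix of the newest export of an exporter.

    Hypothetical helper. Assumes the layout
    <model_dir>/export/<exporter_name>/<timestamp>/variables/variables.*
    where <timestamp> is an all-digit directory name.
    """
    export_root = os.path.join(model_dir, "export", exporter_name)
    # Keep only all-digit directory names; skip anything else,
    # e.g. a partially written temporary export directory.
    timestamps = [
        name for name in os.listdir(export_root)
        if name.isdigit() and os.path.isdir(os.path.join(export_root, name))
    ]
    if not timestamps:
        raise FileNotFoundError("no exports found under " + export_root)
    # Compare numerically so that "1000" sorts after "999".
    newest = max(timestamps, key=int)
    # A checkpoint prefix has no .index / .data-* suffix.
    return os.path.join(export_root, newest, "variables", "variables")
```

The returned string could then be passed as ckpt_to_initialize_from, e.g. latest_export_variables_prefix(model_dir, "best-exporter") for the exporter name used in the path above.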

liuchao answered Oct 22 '22 06:10


Based on the resolution of @SumNeuron's GitHub issue, tf.contrib.estimator.SavedModelEstimator is the way to load a saved model into an Estimator.

The following works for me:

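import tensorflow as tf  # TF 1.x; tf.contrib was removed in TF 2.x
estimator = tf.contrib.estimator.SavedModelEstimator(saved_model_dir)  # saved_model_dir: a timestamped export directory
prediction_results = estimator.predict(input_fn)  # a generator of prediction dicts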

It is baffling that this is essentially completely undocumented.

DavidS answered Oct 22 '22 08:10


I am also new to the Estimator API, but I think I know what you are looking for, although the answer is equally annoying.

From this colab, which is a toy custom Estimator with some bells and whistles added on:

from tensorflow.contrib import predictor  # TF 1.x only
predict_fn = predictor.from_saved_model(<model_dir>)  # path to the SavedModel export (the timestamped directory)
predict_fn(pred_features)  # pred_features: dict keyed by the serving signature's input names

The same Estimator uses a BestExporter:

exporter = tf.estimator.BestExporter(
    name="best_exporter",
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=5
) # keeps the 5 most recent exports; each was the best model when it was written

and also exports the model once more after training:

est.export_savedmodel('./here', serving_input_receiver_fn)  # renamed export_saved_model in later TF 1.x releases
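As far as I can tell from the TF 1.x source, BestExporter with its default compare_fn compares the eval "loss" and only exports when it is strictly smaller than the best loss seen so far, while exports_to_keep garbage-collects all but the most recent exports. The pure-Python simulation below is a simplified model of that behavior (simulate_best_exports is hypothetical, not part of TensorFlow). It illustrates that the newest surviving export is always the lowest-loss one, but the kept exports are not necessarily the 5 best checkpoints overall.

```python
def simulate_best_exports(eval_losses, exports_to_keep=5):
    """Simplified, hypothetical model of a loss-based BestExporter.

    An evaluation is exported only if its loss is strictly smaller than the
    best loss exported so far (the first evaluation is always exported).
    Only the most recent `exports_to_keep` exports survive garbage
    collection; pass a positive int, or None to keep all of them.
    Returns (exported, kept) as lists of evaluation indices.
    """
    best_loss = None
    exported = []
    for index, loss in enumerate(eval_losses):
        if best_loss is None or loss < best_loss:
            best_loss = loss
            exported.append(index)
    kept = exported if exports_to_keep is None else exported[-exports_to_keep:]
    return exported, kept
```

For example, with losses [0.9, 0.5, 0.6] and exports_to_keep=2 the function returns ([0, 1], [0, 1]): evaluation 2 beats evaluation 0 but is never exported, because it did not beat the best loss at the time.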

If it irks you that the Estimator API has no "proper" way to load a SavedModel, I already created an issue on GitHub.

However, if you are trying to load it onto a different device, see my other questions:

  • TensorFlow v1.10+ load SavedModel with different device placement or manually set dynamic device placement?

  • TensorFlow Estimator clear_devices in exporters?

which address device placement. There are also related GitHub issues:

  • No clear_devices in BestExporter #23900

  • Relative Device Placement #23834

In short, if you export using the Estimator exporters, the device you trained on must currently be available when you load the exported model. If you instead export the model manually from within model_fn and set clear_devices, you should be good to go. At the moment there does not seem to be a way to change the device placement after the model has been exported.

SumNeuron answered Oct 22 '22 07:10


I hope someone else finds a cleaner way.

tf.estimator.BestExporter exports the best model like this:

<your_estimator.model_dir>
+--export
   +--best_exporter
      +--xxxxxxxxxx(timestamp)
         +--saved_model.pb
         +--variables
            +--variables.data-00000-of-00001
            +--variables.index

On the other hand, in your_estimator.model_dir, each checkpoint is stored as three files:

model.ckpt-xxxx.data-00000-of-00001
model.ckpt-xxxx.index
model.ckpt-xxxx.meta

First, I used tf.estimator.Estimator.evaluate(..., checkpoint_path='<your_estimator.model_dir>/export/best_exporter/<xxxxxxxxxx>/variables/variables'), but this did not work.

After copying one of the .meta files from your_estimator.model_dir into the export's variables directory and renaming it variables.meta, the evaluation seemed to work all right.
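The manual copy step can be scripted. A sketch, assuming the checkpoint and export layouts shown above; copy_meta_into_export is a hypothetical helper, and picking the .meta file with the highest global step is an arbitrary choice (the workaround only says that one of the .meta files worked). This is a hack rather than a supported API, and it presumably only makes sense when the .meta file comes from the same model_fn used for evaluation.

```python
import os
import re
import shutil

# Matches training checkpoint graph files such as model.ckpt-1234.meta.
_META_RE = re.compile(r"model\.ckpt-(\d+)\.meta")


def copy_meta_into_export(model_dir, export_dir):
    """Copy a training .meta file next to an export's variables checkpoint.

    Hypothetical helper: picks the model.ckpt-<step>.meta with the highest
    global step in model_dir, copies it to
    <export_dir>/variables/variables.meta and returns the checkpoint prefix.
    """
    metas = []
    for name in os.listdir(model_dir):
        match = _META_RE.fullmatch(name)
        if match:
            metas.append((int(match.group(1)), name))
    if not metas:
        raise FileNotFoundError("no model.ckpt-*.meta files in " + model_dir)
    _, newest = max(metas)  # tuple comparison: highest global step wins
    prefix = os.path.join(export_dir, "variables", "variables")
    shutil.copyfile(os.path.join(model_dir, newest), prefix + ".meta")
    return prefix
```

The returned prefix is the value that was passed as checkpoint_path to tf.estimator.Estimator.evaluate() above.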

Taro Kiritani answered Oct 22 '22 08:10