 

Load checkpoint and finetuning using tf.estimator.Estimator

Tags:

tensorflow

We're trying to translate old training code into code that is more compliant with tf.estimator.Estimator. In the original code we fine-tune an existing model on a target dataset. Only some layers are loaded from the checkpoint before training starts, using a combination of variables_to_restore and an init_fn with MonitoredTrainingSession. How can we achieve this kind of selective weight loading with the tf.estimator.Estimator approach?

asked Sep 26 '17 10:09 by jrabary


People also ask

What does TF estimator do?

The Estimator object wraps a model specified by a model_fn, which, given inputs and a number of other parameters, returns the ops necessary to perform training, evaluation, or predictions. All outputs (checkpoints, event files, etc.) are written to model_dir, or a subdirectory thereof.

What kind of estimator model does TensorFlow recommend using for classification?

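Pre-made Estimators are recommended when you are just getting started. To write a TensorFlow program based on pre-made Estimators, you must perform the following tasks: create one or more input functions, define the model's feature columns, instantiate an Estimator (specifying the feature columns and various hyperparameters), and call one or more methods on the Estimator object, passing the appropriate input function as the source of the data.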

What is Model_fn?

The “model_fn” parameter is a function that consumes the features, labels, mode and params in the following order: def model_fn(features, labels, mode, params): The Estimator will always supply those parameters when it executes the model function for training, evaluation or prediction.


1 Answer

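You have two options; the first one is simpler:

1- Call tf.train.init_from_checkpoint in your model_fn, after the variables have been created, with an assignment map that covers only the layers you want to load.

2- model_fn returns an EstimatorSpec, so you can pass a tf.train.Scaffold (for example one whose init_fn restores your variables_to_restore) through the EstimatorSpec's scaffold argument.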
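
A minimal sketch of both options, assuming TF 1.15 or a TF 2.x release up to 2.15 (tf.estimator was removed in 2.16) and a TF1-style, name-based checkpoint. The scope names backbone and head, the tensor shapes, the params keys ("ckpt", "restore_with", "learning_rate") and the helpers make_pretrained_checkpoint and fine_tune are hypothetical stand-ins for the real model. Option 1 maps only the backbone/ scope with tf.compat.v1.train.init_from_checkpoint; option 2 reproduces the variables_to_restore + init_fn pattern from the question through tf.compat.v1.train.Scaffold(init_fn=...), which calls init_fn(scaffold, session) after the regular init op.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf  # assumption: TF 1.15 or TF 2.x <= 2.15, where tf.estimator exists

tf1 = tf.compat.v1


def make_pretrained_checkpoint(ckpt_dir):
    """Hypothetical 'original model': a backbone plus a 10-class head, saved TF1-style."""
    with tf.Graph().as_default():
        with tf1.variable_scope("backbone"):
            tf1.get_variable("kernel", initializer=tf.fill([4, 3], 0.5))
        with tf1.variable_scope("head"):
            tf1.get_variable("kernel", initializer=tf.fill([3, 10], 2.0))
        saver = tf1.train.Saver()
        with tf1.Session() as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, os.path.join(ckpt_dir, "pretrained.ckpt"))


def model_fn(features, labels, mode, params):
    with tf1.variable_scope("backbone"):
        hidden = tf.matmul(features["x"], tf1.get_variable("kernel", shape=[4, 3]))
    with tf1.variable_scope("head"):
        # New 2-class head: it is never restored (its shape differs from the checkpoint).
        logits = tf.matmul(hidden, tf1.get_variable("kernel", shape=[3, 2]))

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode, predictions={"logits": logits})

    scaffold = None
    if params["restore_with"] == "init_from_checkpoint":
        # Option 1: replace the initializers of backbone/* with checkpoint values.
        # Keys are scopes in the checkpoint, values are scopes in the current graph.
        tf1.train.init_from_checkpoint(params["ckpt"], {"backbone/": "backbone/"})
    else:
        # Option 2: the old variables_to_restore + init_fn pattern, passed via a Scaffold.
        variables_to_restore = tf1.get_collection(
            tf1.GraphKeys.GLOBAL_VARIABLES, scope="backbone/")
        saver = tf1.train.Saver(variables_to_restore)

        def init_fn(scaffold, session):
            saver.restore(session, params["ckpt"])

        scaffold = tf1.train.Scaffold(init_fn=init_fn)

    loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    train_op = tf1.train.GradientDescentOptimizer(params["learning_rate"]).minimize(
        loss, global_step=tf1.train.get_global_step())
    return tf.estimator.EstimatorSpec(
        mode, loss=loss, train_op=train_op, scaffold=scaffold)


def input_fn():
    x = np.ones((8, 4), dtype=np.float32)
    y = np.zeros((8,), dtype=np.int32)
    return tf.data.Dataset.from_tensor_slices(({"x": x}, y)).batch(4)


def fine_tune(restore_with, learning_rate=0.1):
    """Hypothetical helper: fine-tune for one step in a fresh model_dir."""
    ckpt = make_pretrained_checkpoint(tempfile.mkdtemp())
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=tempfile.mkdtemp(),  # empty dir, so the pretrained weights get loaded
        params={"ckpt": ckpt, "restore_with": restore_with,
                "learning_rate": learning_rate})
    estimator.train(input_fn, steps=1)
    return estimator


if __name__ == "__main__":
    est = fine_tune("scaffold")
    print(est.get_variable_value("backbone/kernel"))
```

Option 1 needs no session callback, which is why it is the simpler of the two; option 2 stays closer to the MonitoredTrainingSession code in the question. In both cases the pretrained checkpoint is only read when model_dir does not yet contain a checkpoint; once the Estimator has saved its own, later train() calls resume from model_dir instead.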

answered Jan 03 '23 01:01 by user1454804