TensorFlow Estimator - warm_start_from and model_dir

When using tf.estimator with both warm_start_from and model_dir, and both the warm_start_from directory and the model_dir directory contain valid checkpoints, which checkpoint will actually be restored?

To give some context, my estimator code looks like this:

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    warm_start_from=warm_start_dir)

for epoch in range(num_epochs):
    est.train(input_fn=train_input_fn)
    est.evaluate(input_fn=eval_input_fn)

(Input functions use one-shot iterators.)
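(For reference, a minimal sketch of what such an input_fn might look like; the dataset contents here are placeholders, not the asker's actual pipeline.)

def train_input_fn():
    # Placeholder data; the real pipeline would build the actual dataset.
    dataset = tf.data.Dataset.from_tensor_slices((features, labels))
    dataset = dataset.shuffle(1000).batch(32)
    # A one-shot iterator runs through the data exactly once, so each
    # est.train() call above consumes one epoch.
    return dataset.make_one_shot_iterator().get_next()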

So during the first iteration, when model_dir is empty, I want the warm-start checkpoint to be loaded, but in subsequent epochs I'd like the intermediate fine-tuned checkpoint from the previous iteration in model_dir to be loaded instead. At least judging from the logs, though, warm_start_dir still seems to be loaded every time.

I could probably rebuild my estimator without warm_start_from for the subsequent iterations, but I wonder whether this behavior shouldn't be built into the estimator somehow.
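(For illustration, a minimal sketch of that kind of workaround, reusing the names from the snippet above: only pass warm_start_from while model_dir does not yet contain a checkpoint.)

import tensorflow as tf

# Warm-start only while model_dir has no checkpoint of its own;
# once training has produced one, resume from model_dir instead.
warm_start = warm_start_dir if tf.train.latest_checkpoint(model_dir) is None else None

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    warm_start_from=warm_start)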

asked Apr 15 '18 by mtngld

1 Answer

I've had a similar issue and solved it by providing an initialization hook that runs when the session is started, combined with tf.estimator.train_and_evaluate (I can't take full credit for this solution, as I saw something similar used for another purpose elsewhere):

import logging

import tensorflow as tf


class InitHook(tf.train.SessionRunHook):
    """Initializes the model from a pre-trained checkpoint.

    Args:
        checkpoint_dir: directory to search for the latest checkpoint
    """
    def __init__(self, checkpoint_dir):
        self.modelPath = checkpoint_dir
        self.initialized = False

    def begin(self):
        """
        Restore parameters if a pre-trained model is available and we
        haven't trained previously. begin() runs once before the session
        is created, while variable initializers can still be overridden,
        which is the point at which warm-starting is possible.
        """
        if not self.initialized:
            log = logging.getLogger('tensorflow')
            checkpoint = tf.train.latest_checkpoint(self.modelPath)
            if checkpoint is None:
                log.info('No pre-trained model is available, training from scratch.')
            else:
                log.info('Pre-trained model {0} found in {1} - warmstarting.'.format(checkpoint, self.modelPath))
                tf.train.warm_start(checkpoint)
            self.initialized = True

Then, for training:

initHook = InitHook(checkpoint_dir=warm_start_dir)
trainSpec = tf.estimator.TrainSpec(
    input_fn=train_input_fn,
    max_steps=N_STEPS,
    hooks=[initHook]  # warm-start hook runs before the session is created
)
evalSpec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,
    steps=None,  # evaluate on the full eval dataset
    name='eval',
    throttle_secs=3600
)
tf.estimator.train_and_evaluate(estimator, trainSpec, evalSpec)

The hook runs once at the beginning to initialize variables from warm_start_dir. Later, once there are new checkpoints in the estimator's model_dir, those take precedence over the warm-start initializers, and training continues from there.
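(If you would rather keep the per-epoch loop from the question than switch to train_and_evaluate, the same hook can presumably be passed to est.train directly; a sketch, reusing the names from the question:)

initHook = InitHook(checkpoint_dir=warm_start_dir)

est = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir)  # no warm_start_from here; the hook handles it

for epoch in range(num_epochs):
    est.train(input_fn=train_input_fn, hooks=[initHook])
    est.evaluate(input_fn=eval_input_fn)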

answered Jan 03 '23 by kamyonet