
Early stopping with tf.estimator, how?

I'm using tf.estimator in TensorFlow 1.4, and tf.estimator.train_and_evaluate is great, but I need early stopping. What's the preferred way of adding that?

I assume there is some tf.train.SessionRunHook somewhere for this. I saw that there was an old contrib package with a ValidationMonitor that seemed to support early stopping, but it no longer appears to exist in 1.4. Or will the preferred way in the future be to rely on tf.keras (with which early stopping is really easy) instead of tf.estimator/tf.layers/tf.data, perhaps?

asked Nov 06 '17 by Carl Thomé


3 Answers

Good news! tf.estimator now has early-stopping support on master, and it looks like it will be in 1.10.

estimator = tf.estimator.Estimator(model_fn, model_dir)

os.makedirs(estimator.eval_dir())  # TODO This should not be expected IMO.

early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(
    estimator,
    metric_name='loss',
    max_steps_without_decrease=1000,
    min_steps=100)

tf.estimator.train_and_evaluate(
    estimator,
    train_spec=tf.estimator.TrainSpec(train_input_fn, hooks=[early_stopping]),
    eval_spec=tf.estimator.EvalSpec(eval_input_fn))
answered Oct 13 '22 by Carl Thomé


First, you must name the loss tensor so that the early-stopping hook can look it up. If your loss variable is called loss in the estimator's model_fn, the line

loss = tf.identity(loss, name="loss")

right beneath it will work.

Then, create a hook with this code.

class EarlyStopping(tf.train.SessionRunHook):
    def __init__(self, smoothing=.997, tolerance=.03):
        self.lowestloss = float("inf")
        self.currentsmoothedloss = -1
        self.tolerance = tolerance
        self.smoothing = smoothing

    def before_run(self, run_context):
        # Look up the loss tensor by the name given with tf.identity above.
        graph = tf.get_default_graph()
        self.lossop = graph.get_operation_by_name("loss")
        self.element = self.lossop.outputs[0]
        return tf.train.SessionRunArgs([self.element])

    def after_run(self, run_context, run_values):
        loss = run_values.results[0]
        if self.currentsmoothedloss < 0:
            # First step: initialise the running average.
            self.currentsmoothedloss = loss * 1.5
        self.currentsmoothedloss = self.currentsmoothedloss * self.smoothing + loss * (1 - self.smoothing)
        if self.currentsmoothedloss < self.lowestloss:
            self.lowestloss = self.currentsmoothedloss
        if self.currentsmoothedloss > self.lowestloss + self.tolerance:
            run_context.request_stop()
            print("REQUESTED_STOP")
            raise ValueError('Model stopping because loss is increasing (EarlyStopping hook)')

This compares an exponentially smoothed validation loss with its lowest value so far, and stops training if the smoothed loss exceeds that minimum by tolerance. If it stops too early, raising tolerance and smoothing will make it stop later. Keep smoothing below one, or it will never stop.

You can replace the logic in after_run with something else if you want to stop based on a different condition.
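The smoothing-and-threshold condition itself is independent of TensorFlow, so it can be sketched and unit-tested on its own. A minimal version of the same logic (class and parameter names here are illustrative, not from the answer above):

```python
class EmaStopper:
    """Exponential-moving-average early-stopping condition.

    Mirrors the hook's logic: track a smoothed loss, remember its minimum,
    and signal a stop once the smoothed loss rises `tolerance` above it.
    """

    def __init__(self, smoothing=0.9, tolerance=0.05):
        self.smoothed = None
        self.lowest = float("inf")
        self.smoothing = smoothing
        self.tolerance = tolerance

    def update(self, loss):
        """Feed one loss value; return True if training should stop."""
        if self.smoothed is None:
            self.smoothed = loss  # initialise the running average
        self.smoothed = self.smoothed * self.smoothing + loss * (1 - self.smoothing)
        self.lowest = min(self.lowest, self.smoothed)
        return self.smoothed > self.lowest + self.tolerance


stopper = EmaStopper(smoothing=0.5, tolerance=0.1)
# While the loss keeps falling, no stop is requested:
print([stopper.update(l) for l in [1.0, 0.8, 0.6, 0.5]])  # [False, False, False, False]
# Once the loss turns around, the smoothed value climbs past lowest + tolerance:
print(stopper.update(0.5), stopper.update(0.9))  # False True
```

Keeping the condition in a plain class like this makes it easy to tune smoothing and tolerance against a recorded loss curve before wiring it into the SessionRunHook.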

Now, add this hook to the evaluation spec. Your code should look something like this:

eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: eval_input_fn(batchsize),
    steps=100,
    hooks=[EarlyStopping()])

Important note: run_context.request_stop() is broken in the train_and_evaluate call and doesn't stop training, which is why the hook raises a ValueError instead. So you have to wrap the train_and_evaluate call in a try/except block like this:

try:
    tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)
except ValueError as e:
    print("training stopped")

If you don't do this, the code will crash with an error when training stops.

answered Oct 13 '22 by user3806120


Yes, there is tf.train.StopAtStepHook:

This hook requests stop after either a number of steps have been executed or a last step has been reached. Only one of the two options can be specified.

You can also extend it and implement your own stopping strategy based on the step results.

class MyHook(tf.train.SessionRunHook):
    ...
    def after_run(self, run_context, run_values):
        if condition:  # e.g. a threshold on a metric computed from run_values
            run_context.request_stop()
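The condition can be any function of the run results; a common choice is patience-based stopping (stop after N consecutive evaluations without improvement). The counter logic, sketched independently of TensorFlow with illustrative names:

```python
class Patience:
    """Stop after `patience` consecutive non-improving metric values."""

    def __init__(self, patience=3, min_delta=0.0):
        self.best = float("inf")
        self.patience = patience
        self.min_delta = min_delta
        self.bad_steps = 0

    def should_stop(self, metric):
        if metric < self.best - self.min_delta:
            # Improvement: record the new best and reset the counter.
            self.best = metric
            self.bad_steps = 0
        else:
            self.bad_steps += 1
        return self.bad_steps >= self.patience


p = Patience(patience=2)
losses = [1.0, 0.9, 0.95, 0.92, 0.91]
print([p.should_stop(l) for l in losses])  # [False, False, False, True, True]
```

Inside a hook, should_stop(...) would be evaluated in after_run and, when it returns True, followed by run_context.request_stop().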
answered Oct 13 '22 by Maxim