 

Replace Validation Monitors with tf.train.SessionRunHook when using Estimators


I am running a DNNClassifier and monitoring accuracy during training. monitors.ValidationMonitor from contrib/learn has been working great; in my implementation I define it as:

validation_monitor = skflow.monitors.ValidationMonitor(input_fn=lambda: input_fn(A_test, Cl2_test), eval_steps=1, every_n_steps=50)

and then call it from:

clf.fit(input_fn=lambda: input_fn(A, Cl2),
            steps=1000, monitors=[validation_monitor])

where:

clf = tensorflow.contrib.learn.DNNClassifier(...

This works fine. That said, validation monitors appear to be deprecated, with similar functionality to be provided by tf.train.SessionRunHook instead.

I am a newbie in TensorFlow, and it is not obvious to me what such a replacement implementation would look like. Any suggestions are highly appreciated. Again, I need to validate the training after a specific number of steps. Thanks very much in advance.

asked Jun 28 '17 by user3807125


2 Answers

I managed to come up with a way to extend tf.train.SessionRunHook as suggested.

import tensorflow as tf


class ValidationHook(tf.train.SessionRunHook):
    """Runs evaluation on a held-out set every N seconds or steps during training."""

    def __init__(self, model_fn, params, input_fn, checkpoint_dir,
                 every_n_secs=None, every_n_steps=None):
        self._iter_count = 0
        # A second Estimator pointed at the training checkpoint directory,
        # so evaluate() picks up the latest checkpoint written by training.
        self._estimator = tf.estimator.Estimator(
            model_fn=model_fn,
            params=params,
            model_dir=checkpoint_dir
        )
        self._input_fn = input_fn
        # Fires on whichever of every_n_secs / every_n_steps is set.
        self._timer = tf.train.SecondOrStepTimer(every_n_secs, every_n_steps)
        self._should_trigger = False

    def begin(self):
        self._timer.reset()
        self._iter_count = 0

    def before_run(self, run_context):
        # Decide before each training step whether to evaluate after it.
        self._should_trigger = self._timer.should_trigger_for_step(self._iter_count)

    def after_run(self, run_context, run_values):
        if self._should_trigger:
            self._estimator.evaluate(
                self._input_fn
            )
            self._timer.update_last_triggered_step(self._iter_count)
        self._iter_count += 1

and used it as a training_hook in Estimator.train:

estimator.train(input_fn=_input_fn(...),
                steps=num_epochs * num_steps_per_epoch,
                hooks=[ValidationHook(...)])
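
For concreteness, a hypothetical instantiation might look like the following; my_model_fn, my_params, and model_dir are placeholders for your own setup and are not from the original answer, while the input_fn calls mirror the ones in the question.

validation_hook = ValidationHook(
    model_fn=my_model_fn,      # placeholder: the same model_fn the training Estimator uses
    params=my_params,          # placeholder: the hyperparameter dict passed to model_fn
    input_fn=lambda: input_fn(A_test, Cl2_test),  # evaluation data, as in the question
    checkpoint_dir=model_dir,  # the training Estimator's model_dir, so evaluate() finds checkpoints
    every_n_steps=50)          # evaluate every 50 training steps

estimator.train(input_fn=lambda: input_fn(A, Cl2),
                steps=1000,
                hooks=[validation_hook])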

It doesn't have any of the fancy features a ValidationMonitor has, like early stopping and whatnot, but it should be a start.
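
For reference, here is a minimal sketch of how early stopping could be bolted onto this hook; it assumes the dict returned by evaluate() exposes a 'loss' key, and the patience value is an arbitrary illustrative choice:

class EarlyStoppingValidationHook(ValidationHook):
    """Hypothetical extension: stop training when the eval loss stops improving."""

    def __init__(self, *args, **kwargs):
        # 'patience' is an illustrative knob: number of evaluations without
        # improvement to tolerate before requesting a stop.
        self._patience = kwargs.pop('patience', 5)
        super(EarlyStoppingValidationHook, self).__init__(*args, **kwargs)
        self._best_loss = None
        self._bad_evals = 0

    def after_run(self, run_context, run_values):
        if self._should_trigger:
            results = self._estimator.evaluate(self._input_fn)
            loss = results['loss']  # assumes evaluate() returns a 'loss' entry
            if self._best_loss is None or loss < self._best_loss:
                self._best_loss = loss
                self._bad_evals = 0
            else:
                self._bad_evals += 1
                if self._bad_evals >= self._patience:
                    run_context.request_stop()  # asks the monitored session to end training
            self._timer.update_last_triggered_step(self._iter_count)
        self._iter_count += 1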

answered Oct 15 '22 by Rocket Pingu


There's an undocumented utility called monitors.replace_monitors_with_hooks() which converts monitors to hooks. The method accepts (i) a list which may contain both monitors and hooks and (ii) the Estimator for which the hooks will be used, and then returns a list of hooks by wrapping a SessionRunHook around each Monitor.

from tensorflow.contrib.learn.python.learn import monitors as monitor_lib

clf = tf.estimator.Estimator(...)

list_of_monitors_and_hooks = [tf.contrib.learn.monitors.ValidationMonitor(...)]
hooks = monitor_lib.replace_monitors_with_hooks(list_of_monitors_and_hooks, clf)

This isn't really a true solution to the problem of fully replacing the ValidationMonitor—we're just wrapping it up with a non-deprecated function instead. However, I can say this has worked for me so far in that it maintained all the functionality I need from the ValidationMonitor (i.e. evaluating every n steps, early stopping using a metric, etc.)
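
As an illustration, here is a hedged sketch of wrapping a monitor configured with both an evaluation interval and early stopping; the specific parameter values are assumptions for demonstration, not taken from the original answer:

validation_monitor = tf.contrib.learn.monitors.ValidationMonitor(
    input_fn=lambda: input_fn(A_test, Cl2_test),  # held-out data, as in the question
    every_n_steps=50,                  # evaluate every 50 training steps
    early_stopping_metric='loss',      # metric watched for early stopping
    early_stopping_rounds=200)         # stop if no improvement for this many steps

hooks = monitor_lib.replace_monitors_with_hooks([validation_monitor], clf)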

One more thing—to use this hook you'll need to update from a tf.contrib.learn.Estimator (which only accepts monitors) to the more full-fledged and official tf.estimator.Estimator (which only accepts hooks). So, you should instantiate your classifier as a tf.estimator.DNNClassifier, and train using its method train() instead (which is just a re-naming of fit()):

clf = tf.estimator.Estimator(...)

...

clf.train(
    input_fn=...
    ...
    hooks=hooks)

answered Oct 15 '22 by lukearend