 

Validation during training of Estimator

Tags:

tensorflow

With TensorFlow r1.3, monitors are deprecated:

"Monitors are deprecated. Please use tf.train.SessionRunHook." (deprecation date 2016-12-05)

and Estimator.train(input_fn, hooks, ...) works only with hooks.

How can I implement the functionality of the validation monitor with hooks?

Aga asked Jul 31 '17 13:07



2 Answers

EDIT: As pointed out in the comments, this feels like the right thing to do, but will reinitialize the weights every time it's evaluated, which makes it pretty much useless...


I ended up being able to monitor my validation error (which is what I understand you are trying to do) using the train_and_evaluate function. The EvalSpec object you have to use has parameters start_delay_secs and throttle_secs to define the frequency at which the error (or whatever you have defined in your estimator's EVAL mode) will be computed.

My code looks somewhat like this:

import tensorflow as tf

# model_fn, model_dir, params, input_fn and valid_input_fn are defined elsewhere.
classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params=params)

train_spec = tf.estimator.TrainSpec(
    input_fn=input_fn,
)

eval_spec = tf.estimator.EvalSpec(
    input_fn=valid_input_fn,
    throttle_secs=120,     # evaluate at most once every 120 seconds
    start_delay_secs=120,  # wait 120 seconds before the first evaluation
)

tf.estimator.train_and_evaluate(
    classifier,
    train_spec,
    eval_spec
)
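
One caveat worth adding (my note, not part of the answer above): with train_and_evaluate, evaluation only runs once a new checkpoint has been written, so the effective evaluation frequency is also bounded by the checkpoint interval (600 seconds by default). If evaluations appear less often than throttle_secs suggests, a RunConfig can align the two; the 120 seconds below is just an illustrative value matching the EvalSpec above.

run_config = tf.estimator.RunConfig(save_checkpoints_secs=120)

classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir=model_dir,
    params=params,
    config=run_config)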
simlmx answered Oct 22 '22 00:10


I have been using SummarySaverHook instead of Monitors. They are not "as powerful" just yet, and the training material has not been updated with a description of how exactly to replicate the Monitor functionality.

Here is how I use it:

# SAVE_EVERY_N_STEPS, classifier and train_input_fn are defined elsewhere.
summary_hook = tf.train.SummarySaverHook(
    save_steps=SAVE_EVERY_N_STEPS,
    output_dir='./tmp/rnnStats',
    scaffold=tf.train.Scaffold(),
    summary_op=tf.summary.merge_all())

print("Classifier.train")
classifier.train(input_fn=train_input_fn, steps=1000, hooks=[summary_hook])
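
If you also need validation metrics rather than only training summaries, a minimal sketch (my addition, assuming a valid_input_fn for the validation set; the step counts are illustrative) is to alternate short train calls with explicit evaluate calls:

for _ in range(10):  # illustrative number of train/evaluate rounds
    classifier.train(input_fn=train_input_fn, steps=100, hooks=[summary_hook])
    eval_results = classifier.evaluate(input_fn=valid_input_fn)
    print(eval_results)  # dict of metrics from the model_fn's EVAL mode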
Axen answered Oct 22 '22 00:10