Nested cross-validation: How does cross_validate handle GridSearchCV as its input estimator?

The following code combines cross_validate with GridSearchCV to perform a nested cross-validation for an SVC on the iris dataset.

(Modified example of the following documentation page: https://scikit-learn.org/stable/auto_examples/model_selection/plot_nested_cross_validation_iris.html#sphx-glr-auto-examples-model-selection-plot-nested-cross-validation-iris-py.)


from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, cross_validate, KFold
import numpy as np
np.set_printoptions(precision=2)

# Load the dataset
iris = load_iris()
X_iris = iris.data
y_iris = iris.target

# Set up possible values of parameters to optimize over
p_grid = {"C": [1, 10],
          "gamma": [.01, .1]}

# We will use a Support Vector Classifier with "rbf" kernel
svm = SVC(kernel="rbf")

# Choose techniques for the inner and outer loop of nested cross-validation
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)

# Perform nested cross-validation
clf = GridSearchCV(estimator=svm, param_grid=p_grid, cv=inner_cv, iid=False)
clf.fit(X_iris, y_iris)
best_estimator = clf.best_estimator_

cv_dic = cross_validate(clf, X_iris, y_iris, cv=outer_cv, scoring=['accuracy'], return_estimator=False, return_train_score=True)
mean_val_score = cv_dic['test_accuracy'].mean()

print('nested_train_scores: ', cv_dic['train_accuracy'])
print('nested_val_scores:   ', cv_dic['test_accuracy'])
print('mean score:            {0:.2f}'.format(mean_val_score))

cross_validate splits the data set into a training set and a test set in each fold. The input estimator is then trained on the training set associated with that fold. The estimator passed in here is clf, a parameterized GridSearchCV estimator, i.e. an estimator that itself performs cross-validation again.
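For illustration, here is a minimal sketch of the outer splitting as I understand it (this snippet is only an illustration of the fold sizes, not part of the actual pipeline above):

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

# Sketch: the 4 outer folds that cross_validate iterates over; the estimator is
# fitted on the training part of each fold and scored on the test part.
X_iris, y_iris = load_iris(return_X_y=True)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)
for fold, (train_idx, test_idx) in enumerate(outer_cv.split(X_iris)):
    print("outer fold", fold, ":", len(train_idx), "train /", len(test_idx), "test samples")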

I have three questions about the whole thing:

  1. If clf is used as the estimator for cross_validate, does it (in the course of the GridSearchCV cross-validation) split the above-mentioned training set into a sub-training set and a validation set in order to determine the best hyperparameter combination?
  2. Out of all models tested via GridSearchCV, does cross_validate validate only the model stored in the best_estimator_ attribute?
  3. Does cross_validate train a model at all (if so, why?) or is the model stored in best_estimator_ validated directly via the test set?

To make it clearer what I mean by these questions, here is an illustration of how I currently picture the double cross-validation.

[Illustration: diagram of the nested (double) cross-validation as imagined by the author]

asked Mar 06 '19 by zwithouta

People also ask

How does cross-validation work in GridSearchCV?

Cross-validation is used while training the model. Before training the model with data, we divide the data into two parts: train data and test data. In cross-validation, the train data is divided further into two parts: the train data and the validation data.
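For example (a minimal sketch using the usual scikit-learn iris dataset and an SVC, not tied to the question's exact code):

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, cross_val_score

# Split off a held-out test set first, then let cross-validation carve the
# remaining train data into train/validation parts for each of the 5 folds.
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
scores = cross_val_score(SVC(kernel="rbf"), X_train, y_train, cv=5)
print(scores)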

What is a drawback with using nested cross-validation?

The problem is that if this score alone is used to then select a model, or the same dataset is used to evaluate the tuned models, then the selection process will be biased by this inadvertent overfitting. The result is an overly optimistic estimate of model performance that does not generalize to new data.

What is nested cross-validation?

Nested cross-validation (CV) is often used to train a model in which hyperparameters also need to be optimized. Nested CV estimates the generalization error of the underlying model and its (hyper)parameter search.

Does grid search do cross-validation?

The Scikit-Learn library comes with a grid search cross-validation implementation. GridSearchCV tries every combination in the parameter grid for a model and returns the set of parameters with the best performance score.
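For example (a minimal sketch; ParameterGrid is the public scikit-learn helper that enumerates the same combinations GridSearchCV evaluates, shown here only to list the candidates):

from sklearn.model_selection import ParameterGrid

# The question's grid yields 2 x 2 = 4 candidate parameter settings; GridSearchCV
# scores each of them via cross-validation and keeps the best one.
p_grid = {"C": [1, 10], "gamma": [.01, .1]}
for params in ParameterGrid(p_grid):
    print(params)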




1 Answer

If clf is used as the estimator for cross_validate, does it split the above mentioned training set into a subtraining set and a validation set in order to determine the best hyper parameter combination?

Yes, as you can see here at line 230, the training set is again split into a sub-training set and a validation set (specifically at line 240).

Update: Yes, when you pass the GridSearchCV classifier into cross_validate, it will again split the training set into a train and a test set. Here is a link describing this in more detail. Your diagram and assumption are correct.
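To see the nested splitting concretely, here is a rough sketch of my own (it only mimics the index handling, it is not scikit-learn's internal code): take the first outer training fold and let the inner KFold split that subset again into sub-training and validation parts.

from sklearn.datasets import load_iris
from sklearn.model_selection import KFold

X_iris, _ = load_iris(return_X_y=True)
inner_cv = KFold(n_splits=5, shuffle=True, random_state=1)
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)

# First outer fold: this training part is all that GridSearchCV gets to see.
outer_train_idx, outer_test_idx = next(iter(outer_cv.split(X_iris)))
X_outer_train = X_iris[outer_train_idx]
# The inner CV then splits that subset again for the hyperparameter search.
for sub_train_idx, val_idx in inner_cv.split(X_outer_train):
    print(len(sub_train_idx), "sub-training /", len(val_idx), "validation samples")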

Out of all models tested via GridSearchCV, does cross_validate train & validate only the model stored in the variable best_estimator?

Yes, as you can see from the answers here and here, GridSearchCV returns the best_estimator_ in your case (since the refit parameter is True by default). However, this best estimator still has to be trained again.
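As a quick sketch of the refit behaviour (my own minimal example reusing the question's grid, not the original code):

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, KFold

X_iris, y_iris = load_iris(return_X_y=True)
clf = GridSearchCV(SVC(kernel="rbf"),
                   {"C": [1, 10], "gamma": [.01, .1]},
                   cv=KFold(n_splits=5, shuffle=True, random_state=1))
clf.fit(X_iris, y_iris)
# With refit=True (the default) the winning combination is refit on all the
# data passed to fit() and exposed as best_estimator_.
print(clf.best_params_)
print(clf.best_estimator_)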

Does cross_validate train a model at all (if so, why?) or is the model stored in best_estimator_ validated directly via the test set?

As for your third and final question: yes, it trains an estimator and returns it if return_estimator is set to True (see this line). This makes sense, since how else would it return the scores without training an estimator in the first place?
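For example (a minimal sketch, not the original code): with return_estimator=True you get back the GridSearchCV that was fitted in each outer fold, and you can check that each fold may even pick different hyperparameters.

from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, KFold, cross_validate

X_iris, y_iris = load_iris(return_X_y=True)
clf = GridSearchCV(SVC(kernel="rbf"), {"C": [1, 10], "gamma": [.01, .1]},
                   cv=KFold(n_splits=5, shuffle=True, random_state=1))
cv_dic = cross_validate(clf, X_iris, y_iris,
                        cv=KFold(n_splits=4, shuffle=True, random_state=1),
                        return_estimator=True)
# One fitted GridSearchCV per outer fold; the chosen parameters can differ per fold.
for fold_clf in cv_dic["estimator"]:
    print(fold_clf.best_params_)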

Update: The reason the model is trained again is that the default use case for cross_validate does not assume that you pass in the best classifier with the optimal parameters. In this case specifically, you are passing in a classifier from GridSearchCV, but if you pass in any untrained classifier it is supposed to be trained. What I mean is that, yes, in your case it shouldn't need to train it again, since you are already doing cross-validation with GridSearchCV and using the best estimator. However, there is no way for cross_validate to know this, so it assumes that you are passing in an un-optimized, or rather untrained, estimator; it therefore has to train it again and return the scores for it.
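Roughly speaking (this is my own paraphrase, not the library source), what cross_validate does per outer fold is equivalent to:

from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV, KFold

X_iris, y_iris = load_iris(return_X_y=True)
clf = GridSearchCV(SVC(kernel="rbf"), {"C": [1, 10], "gamma": [.01, .1]},
                   cv=KFold(n_splits=5, shuffle=True, random_state=1))
outer_cv = KFold(n_splits=4, shuffle=True, random_state=1)

for train_idx, test_idx in outer_cv.split(X_iris):
    fold_clf = clone(clf)                               # unfitted copy; any earlier clf.fit() is ignored
    fold_clf.fit(X_iris[train_idx], y_iris[train_idx])  # the inner grid search runs here
    print(fold_clf.score(X_iris[test_idx], y_iris[test_idx]))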

answered Oct 02 '22 by Gambit1614