 

Is there a way to see the folds for cross-validation in GridSearchCV?

I'm currently doing a 3-fold CV with GridSearchCV in Python to optimize hyperparameters. Is there any way to see the indices of the training and test data in each fold that GridSearchCV uses?

asked Feb 02 '17 20:02 by Frederica


1 Answer

Yes, as long as you don't shuffle the samples before splitting them into folds. Pass an instance of KFold (or another CV splitter) to the GridSearchCV constructor, then access its folds through grid.cv like this:

import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, KFold

# 'liblinear' supports both the 'l1' and 'l2' penalties in this grid
params = {'penalty': ['l1', 'l2'], 'C': [1, 2, 3]}
grid = GridSearchCV(LogisticRegression(solver='liblinear'), params, cv=KFold(n_splits=3))

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])

# grid.cv is the KFold instance passed to the constructor
for train, test in grid.cv.split(X):
    print('TRAIN: ', train, ' TEST: ', test)

which prints:

TRAIN:  [2 3 4 5]  TEST:  [0 1]
TRAIN:  [0 1 4 5]  TEST:  [2 3]
TRAIN:  [0 1 2 3]  TEST:  [4 5]

Without shuffling, KFold splits the data deterministically, so these are exactly the folds GridSearchCV uses when you later call grid.fit on the same X.
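If you pass an integer such as cv=3 instead of a KFold object, GridSearchCV resolves it through sklearn.model_selection.check_cv: for a classifier with binary or multiclass targets this gives an unshuffled StratifiedKFold, otherwise an unshuffled KFold. A minimal sketch of reproducing those folds yourself, assuming the hypothetical labels y below:

```python
import numpy as np
from sklearn.model_selection import check_cv

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])
y = np.array([0, 1, 0, 1, 0, 1])  # hypothetical binary labels

# Same resolution GridSearchCV applies to an integer cv when the estimator is a classifier
cv = check_cv(3, y, classifier=True)
print(type(cv).__name__)  # StratifiedKFold

# StratifiedKFold needs y to build the folds, just as GridSearchCV passes it during fit
for train, test in cv.split(X, y):
    print('TRAIN: ', train, ' TEST: ', test)
```

Because this splitter does not shuffle, the folds depend only on X and y, so they match the ones used inside the search.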

If you want to shuffle the samples before folding, it is a little more complicated: with shuffle=True and no fixed random_state, every call to cv.split() produces a different split. There are two ways around this:

  1. Give the CV object a fixed integer random_state, e.g. KFold(n_splits=3, shuffle=True, random_state=42), so every call to split() shuffles the same way.

  2. Before creating the GridSearchCV object, turn the KFold split generator into a list of (train, test) index pairs and pass that list as cv.
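For the first option, a quick sanity check that a seeded KFold yields the same shuffled folds on every call to split(). This is a minimal sketch reusing the toy X from the first example; fold_lists is a hypothetical helper, not part of scikit-learn:

```python
import numpy as np
from sklearn.model_selection import KFold

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [5, 6], [7, 8]])

def fold_lists(cv, X):
    """Hypothetical helper: materialize cv.split(X) as plain (train, test) lists."""
    return [(train.tolist(), test.tolist()) for train, test in cv.split(X)]

# An integer random_state re-seeds the shuffle on each split() call
seeded = KFold(n_splits=3, shuffle=True, random_state=42)
print(fold_lists(seeded, X) == fold_lists(seeded, X))  # True

for train, test in seeded.split(X):
    print('TRAIN: ', train, ' TEST: ', test)
```

Passing this seeded object as cv to GridSearchCV should therefore reproduce the printed folds during fit, provided X has the same number of samples.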

So, for the second approach, do:

grid = GridSearchCV(LogisticRegression(), params, 
                    cv=list(KFold(n_splits=3, shuffle=True).split(X)))

Unlike the generator returned by split(), a list is a fixed object: unless you modify it yourself, it holds the same folds for every parameter combination in the grid search, and you can inspect them at any time through grid.cv.
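A minimal end-to-end sketch of the second approach, assuming hypothetical random features and balanced labels (so every training fold contains both classes) and a simplified grid over C only. After fitting, grid.cv is still the very list that was passed in, so the printed folds are the ones the search actually used:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, KFold

rng = np.random.RandomState(0)
X = rng.normal(size=(60, 2))  # hypothetical toy features
y = np.array([0, 1] * 30)     # balanced labels: each 40-sample training fold sees both classes

# Materialize the shuffled folds once, before building the search
splits = list(KFold(n_splits=3, shuffle=True).split(X))

grid = GridSearchCV(LogisticRegression(), {'C': [0.1, 1.0, 10.0]}, cv=splits)
grid.fit(X, y)

# grid.cv is the same list object, so these are exactly the folds used in fit()
for i, (train, test) in enumerate(grid.cv):
    print(f'fold {i}: {len(train)} train / {len(test)} test, first test indices {test[:5]}')
print(grid.n_splits_)  # 3
```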

answered Nov 15 '22 12:11 by Toterich