
Using sample_weight in GridSearchCV

Is it possible to perform a GridSearchCV (to get the best SVM's C) and yet specify the sample_weight with scikit-learn?

Here's my code and the error I get:

gs = GridSearchCV(
    svm.SVC(C=1),
    [{
        'kernel': ['linear'],
        'C': [.1, 1, 10],
        'probability': [True],
        'sample_weight': sw_train,
    }]
)

gs.fit(Xtrain, ytrain)

>> ValueError: Invalid parameter sample_weight for estimator SVC


Edit: I solved the issue by getting the latest scikit-learn version and using the following:

gs.fit(Xtrain, ytrain, fit_params={'sample_weight': sw_train})
user1771485 asked Oct 24 '12 14:10



2 Answers

Just trying to close out this long-standing question...

You needed to get the latest version of scikit-learn and use the following:

gs.fit(Xtrain, ytrain, fit_params={'sample_weight': sw_train})

However, it is more in line with the documentation to pass fit_params to the constructor:

gs = GridSearchCV(
    svm.SVC(C=1),
    [{'kernel': ['linear'], 'C': [.1, 1, 10], 'probability': [True]}],
    fit_params={'sample_weight': sw_train},
)

gs.fit(Xtrain, ytrain)
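
For context on the original error: the keys of the parameter grid are applied to the estimator as constructor parameters through set_params, and sample_weight is an argument of fit rather than of the SVC constructor. A minimal sketch that reproduces the ValueError outside of GridSearchCV (the weight values here are made up for illustration):

```python
from sklearn import svm

clf = svm.SVC()

# sample_weight is not among the constructor parameters of SVC
print('sample_weight' in clf.get_params())  # False

# Setting it the way a parameter grid would raises the same kind of error
try:
    clf.set_params(sample_weight=[1.0, 2.0])
    raised = False
except ValueError as exc:
    raised = True
    print(exc)
```

So anything that belongs to fit has to be kept out of the grid and supplied separately.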
AN6U5 answered Oct 19 '22 04:10


The previous answers are now obsolete. The dictionary fit_params should be passed to the fit method.

From the documentation for GridSearchCV:

fit_params : dict, optional

Parameters to pass to the fit method.

Deprecated since version 0.19: fit_params as a constructor argument was deprecated in version 0.19 and will be removed in version 0.21. Pass fit parameters to the fit method instead.
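
As a rough sketch of what passing fit parameters to the fit method looks like in recent releases, where fit takes them as keyword arguments: the data below is synthetic and the weights are random, both purely for illustration, and the import path assumes a version that has sklearn.model_selection.

```python
import numpy as np
from sklearn import svm
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV

# Synthetic stand-ins for Xtrain, ytrain and sw_train (illustrative only)
Xtrain, ytrain = make_classification(n_samples=60, n_features=4, random_state=0)
rng = np.random.RandomState(0)
sw_train = rng.uniform(0.5, 2.0, size=len(ytrain))

# Only constructor parameters go in the grid; sample_weight does not
gs = GridSearchCV(
    svm.SVC(kernel='linear'),
    param_grid={'C': [0.1, 1, 10]},
    cv=3,
)

# Fit parameters are passed as keyword arguments to fit; an array with one
# entry per sample is split across the CV folds together with X and y
gs.fit(Xtrain, ytrain, sample_weight=sw_train)
print(gs.best_params_)
```

As far as I know, with the default settings the weights are used only when fitting each candidate, not when scoring it on the validation folds, so check the documentation of your version if you need weighted scores.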

Sycorax answered Oct 19 '22 05:10