 

How to graph grid scores from GridSearchCV?

I am looking for a way to graph grid_scores_ from GridSearchCV in sklearn. In this example I am trying to grid search for the best gamma and C parameters for an SVR algorithm. My code looks as follows:

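    import numpy as np
    from sklearn.svm import SVR
    from sklearn.grid_search import GridSearchCV  # pre-0.20 module that still provides grid_scores_

    C_range = 10.0 ** np.arange(-4, 4)
    gamma_range = 10.0 ** np.arange(-4, 4)
    param_grid = dict(gamma=gamma_range.tolist(), C=C_range.tolist())
    grid = GridSearchCV(SVR(kernel='rbf', gamma=0.1), param_grid, cv=5)
    grid.fit(X_train, y_train)
    print(grid.grid_scores_)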

After I run the code and print the grid scores I get the following outcome:

    [mean: -3.28593, std: 1.69134, params: {'gamma': 0.0001, 'C': 0.0001},
     mean: -3.29370, std: 1.69346, params: {'gamma': 0.001, 'C': 0.0001},
     mean: -3.28933, std: 1.69104, params: {'gamma': 0.01, 'C': 0.0001},
     mean: -3.28925, std: 1.69106, params: {'gamma': 0.1, 'C': 0.0001},
     mean: -3.28925, std: 1.69106, params: {'gamma': 1.0, 'C': 0.0001},
     mean: -3.28925, std: 1.69106, params: {'gamma': 10.0, 'C': 0.0001},
     etc]

I would like to visualize all the scores (mean values) depending on gamma and C parameters. The graph I am trying to obtain should look as follows:

[Image: desired plot of mean score against gamma, one line per C value]

Here the x-axis is gamma, the y-axis is the mean score (root mean square error in this case), and each line represents a different C value.
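A minimal sketch of the data-preparation step, assuming each grid_scores_ entry can be reduced to a (params, mean) pair (in old scikit-learn the first two fields of each entry are the parameter dict and the mean validation score). Grouping by C and sorting by gamma yields one series per line without relying on the order in which the grid was evaluated. The values are copied from the output above; the helper name `series_by_c` is hypothetical.

```python
from collections import defaultdict

# (params, mean score) pairs copied from the grid_scores_ output above
entries = [
    ({'gamma': 0.0001, 'C': 0.0001}, -3.28593),
    ({'gamma': 0.001, 'C': 0.0001}, -3.29370),
    ({'gamma': 0.01, 'C': 0.0001}, -3.28933),
    ({'gamma': 0.1, 'C': 0.0001}, -3.28925),
    ({'gamma': 1.0, 'C': 0.0001}, -3.28925),
    ({'gamma': 10.0, 'C': 0.0001}, -3.28925),
]


def series_by_c(pairs):
    """Hypothetical helper: map each C value to (gammas, means), sorted by gamma."""
    grouped = defaultdict(list)
    for params, mean in pairs:
        grouped[params['C']].append((params['gamma'], mean))
    result = {}
    for c, points in grouped.items():
        points.sort()  # sort by gamma so each line runs left to right
        gammas = [g for g, _ in points]
        means = [m for _, m in points]
        result[c] = (gammas, means)
    return result


lines = series_by_c(entries)
for c, (gammas, means) in sorted(lines.items()):
    print('C =', c, gammas, means)
```

Each `(gammas, means)` pair could then be passed to `plt.plot(gammas, means, label='C: ' + str(c))`, and `plt.xscale('log')` keeps the gamma values, which span several orders of magnitude, evenly spaced.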

kroonike asked May 11 '16 11:05





2 Answers

The code shown by @sascha is correct. However, the grid_scores_ attribute is deprecated (it was removed in scikit-learn 0.20), so it is better to use the cv_results_ attribute.

It can be implemented in a similar fashion to @sascha's method:

    import numpy as np
    import matplotlib.pyplot as plt

    def plot_grid_search(cv_results, grid_param_1, grid_param_2, name_param_1, name_param_2):
        # Get test-score means and standard deviations for each grid point.
        # The reshape assumes cv_results lists combinations with param 2 varying slowest
        # (ParameterGrid sorts parameter names, so param 2's name must sort first).
        scores_mean = cv_results['mean_test_score']
        scores_mean = np.array(scores_mean).reshape(len(grid_param_2), len(grid_param_1))

        # Computed here but not used in the plot below
        scores_sd = cv_results['std_test_score']
        scores_sd = np.array(scores_sd).reshape(len(grid_param_2), len(grid_param_1))

        # Plot grid search scores
        _, ax = plt.subplots(1, 1)

        # Param 1 is the x-axis; each value of param 2 is drawn as a separate curve (colored line)
        for idx, val in enumerate(grid_param_2):
            ax.plot(grid_param_1, scores_mean[idx, :], '-o', label=name_param_2 + ': ' + str(val))

        ax.set_title("Grid Search Scores", fontsize=20, fontweight='bold')
        ax.set_xlabel(name_param_1, fontsize=16)
        ax.set_ylabel('CV Average Score', fontsize=16)
        ax.legend(loc="best", fontsize=15)
        ax.grid(True)

    # Calling the method: pipe_grid is a fitted GridSearchCV, and n_estimators and
    # max_features are the value lists it searched over
    plot_grid_search(pipe_grid.cv_results_, n_estimators, max_features, 'N Estimators', 'Max Features')

The above results in the following plot:

[Image: "Grid Search Scores" plot of CV average score against N Estimators, one curve per Max Features value]
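The reshape in `plot_grid_search` depends on the order in which `cv_results_` lists the parameter combinations. A minimal sketch of an order-independent alternative, assuming only that the results dict has the documented `'params'` (a list of parameter dicts) and `'mean_test_score'` entries; the helper name `score_matrix` is hypothetical. Its result has the same layout as `scores_mean` (rows follow param 2, columns follow param 1) and could be wrapped in `np.array` and used in its place. With a Pipeline, the names must be the prefixed keys that appear in `'params'`.

```python
def score_matrix(cv_results, name_1, values_1, name_2, values_2, key='mean_test_score'):
    """Hypothetical helper: return a len(values_2) x len(values_1) nested list.

    Row i corresponds to values_2[i] and column j to values_1[j]. Each score is
    placed by looking up its own parameters, so the order of cv_results does
    not matter.
    """
    col = {v: j for j, v in enumerate(values_1)}
    row = {v: i for i, v in enumerate(values_2)}
    matrix = [[None] * len(values_1) for _ in values_2]
    for params, score in zip(cv_results['params'], cv_results[key]):
        matrix[row[params[name_2]]][col[params[name_1]]] = score
    return matrix


# Toy input shaped like cv_results_; the scores are made up for illustration.
toy_results = {
    'params': [
        {'C': 1.0, 'gamma': 0.1},
        {'C': 10.0, 'gamma': 0.1},
        {'C': 1.0, 'gamma': 1.0},
        {'C': 10.0, 'gamma': 1.0},
    ],
    'mean_test_score': [0.50, 0.70, 0.60, 0.80],
}
matrix = score_matrix(toy_results, 'gamma', [0.1, 1.0], 'C', [1.0, 10.0])
print(matrix)  # [[0.5, 0.6], [0.7, 0.8]]
```

Passing `key='std_test_score'` builds the matching matrix of standard deviations.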

David Alvarez answered Sep 30 '22 08:09



    from sklearn.svm import SVC
    from sklearn.grid_search import GridSearchCV  # module removed in scikit-learn 0.20
    from sklearn import datasets
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target

    clf_ = SVC(kernel='rbf')
    Cs = [1, 10, 100, 1000]
    Gammas = [1e-3, 1e-4]
    clf = GridSearchCV(clf_,
                       dict(C=Cs,
                            gamma=Gammas),
                       cv=2,
                       pre_dispatch='1*n_jobs',
                       n_jobs=1)

    clf.fit(X, y)

    # Each grid_scores_ entry is (parameters, mean_validation_score, cv_validation_scores)
    scores = [x[1] for x in clf.grid_scores_]
    # Rows follow Cs, columns follow Gammas
    scores = np.array(scores).reshape(len(Cs), len(Gammas))

    # One line per C value, with gamma on the x-axis
    for ind, i in enumerate(Cs):
        plt.plot(Gammas, scores[ind], label='C: ' + str(i))
    plt.legend()
    plt.xlabel('Gamma')
    plt.ylabel('Mean score')
    plt.show()
  • Code is based on this.
  • The only puzzling part: will sklearn always respect the order of C and gamma? The official example relies on this ordering.

Output:

[Image: example plot of mean score against gamma, one line per C value]
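On the ordering question raised above: as far as I can tell from the scikit-learn source, `ParameterGrid` sorts the keys of each grid dictionary and then takes the Cartesian product of their value lists, so the alphabetically first key varies slowest. The stdlib sketch below mimics that behaviour (it is not scikit-learn's own code, and `mimic_parameter_grid` is a hypothetical name). Because `'C'` sorts before `'gamma'`, C varies slowest, which is why `reshape(len(Cs), len(Gammas))` lines up with the scores.

```python
from itertools import product


def mimic_parameter_grid(param_grid):
    """Hypothetical mimic of how ParameterGrid orders one grid dict:
    keys are sorted, then itertools.product varies the last key fastest."""
    items = sorted(param_grid.items())
    keys = [k for k, _ in items]
    value_lists = [v for _, v in items]
    return [dict(zip(keys, combo)) for combo in product(*value_lists)]


Cs = [1, 10, 100, 1000]
Gammas = [1e-3, 1e-4]
for params in mimic_parameter_grid(dict(gamma=Gammas, C=Cs))[:3]:
    print(params)
# {'C': 1, 'gamma': 0.001}
# {'C': 1, 'gamma': 0.0001}
# {'C': 10, 'gamma': 0.001}
```

The same rule explains the question's output, where gamma changes on every row while C stays at 0.0001.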

sascha answered Sep 30 '22 07:09
