
Using best params from GridSearchCV

I'm not sure whether this is the right place to ask, but I will ask anyway. If it is not allowed, please let me know.

I used GridSearchCV to tune parameters for the best accuracy. This is what I did:

import numpy as np
from sklearn.model_selection import GridSearchCV  # sklearn.grid_search was removed in scikit-learn 0.20
from sklearn.tree import DecisionTreeClassifier

# Search over split size, tree depth, and split criterion with 8-fold cross-validation
parameters = {'min_samples_split': np.arange(2, 80), 'max_depth': np.arange(2, 10), 'criterion': ['gini', 'entropy']}
clfr = DecisionTreeClassifier()
grid = GridSearchCV(clfr, parameters, scoring='accuracy', cv=8)
grid.fit(X_train, y_train)
print('The parameters combination that would give best accuracy is : ')
print(grid.best_params_)
print('The best accuracy achieved after parameter tuning via grid search is : ', grid.best_score_)

This gives me the following result:

The parameters combination that would give best accuracy is : 
{'max_depth': 5, 'criterion': 'entropy', 'min_samples_split': 2}
The best accuracy achieved after parameter tuning via grid search is :  0.8147086914995224

Now I want to use these parameters when calling a function that visualizes a decision tree.

The function looks something like this:

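import pydotplus
from IPython.display import Image  # assumed: Image is IPython's inline image display class
from sklearn.tree import export_graphviz

def visualize_decision_tree(decision_tree, feature, target):
    # Export the fitted tree as DOT source, then render it to a PNG image
    dot_data = export_graphviz(decision_tree, out_file=None,
                               feature_names=feature,
                               class_names=target,
                               filled=True, rounded=True,
                               special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    return Image(graph.create_png())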

Right now I am trying to pass the best parameters found by GridSearchCV to the classifier, and then call the function, in the following way:

dtBestScore = DecisionTreeClassifier(parameters = grid.best_params_)
dtBestScore = dtBestScore.fit(X=dfWithTrainFeatures, y= dfWithTestFeature)
visualize_decision_tree(dtBestScore, list(dfCopy.columns.delete(0).values), 'survived')

I am getting an error on the first line of this code, which says:

TypeError: __init__() got an unexpected keyword argument 'parameters'

Is there some way to use the best parameters found by grid search automatically, rather than reading the result and manually setting the value of each parameter?

Cybercop asked Jan 05 '17 00:01


1 Answer

Try Python kwargs, which unpack the dictionary so that each key is passed as its own keyword argument:

DecisionTreeClassifier(**grid.best_params_)
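To see why the original call fails and why unpacking fixes it, here is a minimal sketch in plain Python. `TreeLike` is a hypothetical stand-in class (not part of scikit-learn) whose constructor accepts the same three keyword arguments that appear in the question's `best_params_` output.

```python
class TreeLike:
    """Hypothetical stand-in for an estimator configured through keyword arguments."""

    def __init__(self, max_depth=None, criterion="gini", min_samples_split=2):
        self.max_depth = max_depth
        self.criterion = criterion
        self.min_samples_split = min_samples_split


# Same shape as the dictionary printed by grid.best_params_ in the question.
best_params = {"max_depth": 5, "criterion": "entropy", "min_samples_split": 2}

# Passing the whole dict under one name fails: the constructor has no
# argument called 'parameters'.
try:
    TreeLike(parameters=best_params)
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# ** unpacks the dict, so this is equivalent to
# TreeLike(max_depth=5, criterion="entropy", min_samples_split=2).
model = TreeLike(**best_params)
```

The keys of the dictionary have to match the constructor's argument names exactly, which is the case for `best_params_` because grid search takes those names from the parameter grid.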

See http://pythontips.com/2013/08/04/args-and-kwargs-in-python-explained for more on kwargs.
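Applied to scikit-learn, this might look like the sketch below. The synthetic data from `make_classification`, the reduced parameter grid, and the fixed `random_state` are assumptions chosen to keep the example small and repeatable; they stand in for the asker's `X_train` and `y_train`. It also shows `grid.best_estimator_`, which is available because `GridSearchCV` refits the best model on the full training data by default (`refit=True`).

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in for the asker's training data (an assumption).
X_train, y_train = make_classification(n_samples=200, n_features=6, random_state=0)

# A deliberately small grid so the search finishes quickly.
parameters = {
    "min_samples_split": [2, 10],
    "max_depth": [2, 5],
    "criterion": ["gini", "entropy"],
}
grid = GridSearchCV(DecisionTreeClassifier(random_state=0), parameters,
                    scoring="accuracy", cv=4)
grid.fit(X_train, y_train)

# Option 1: unpack the best parameters into a fresh classifier and fit it.
dt_best = DecisionTreeClassifier(**grid.best_params_, random_state=0)
dt_best.fit(X_train, y_train)

# Option 2: reuse the model that GridSearchCV already refit with those parameters.
dt_refit = grid.best_estimator_
```

Either `dt_best` or `dt_refit` can then be passed to a function such as `visualize_decision_tree`, since both are fitted `DecisionTreeClassifier` instances.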

Oliver Dain answered Nov 07 '22 14:11