 

Random Forest Classifier decision path method (scikit)

I've implemented a standard RandomForestClassifier on the Titanic dataset and hope to explore sklearn's decision_path method, which was introduced in v0.18. (http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)

However, it outputs a sparse matrix which I'm not certain how to make sense of. Can anyone advise on how best to visualise this?

from sklearn.ensemble import RandomForestClassifier

# Training a simplified random forest
estimator = RandomForestClassifier(random_state=0, n_estimators=3, max_depth=3)
estimator.fit(X_train, y_train)

# Extracting the decision path for instance i = 12
i_data = X_test.iloc[[12]]  # double brackets keep a one-row DataFrame (2-D input)
d_path = estimator.decision_path(i_data)

print(d_path)

Output:

(<1x3982 sparse matrix of type '' with 598 stored elements in Compressed Sparse Row format>, array([ 0, 45,
98, 149, 190, 233, 258, 309, 360, 401, 430, 461, 512, 541, 580, 623, 668, 711, 760, 803, 852, 889, 932, 981, 1006, 1035, 1074, 1107, 1136, 1165, 1196, 1241, 1262, 1313, 1350, 1385, 1420, 1465, 1518, 1553, 1590, 1625, 1672, 1707, 1744, 1787, 1812, 1863, 1904, 1945, 1982, 2017, 2054, 2097, 2142, 2191, 2228, 2267, 2304, 2343, 2390, 2419, 2456, 2489, 2534, 2583, 2632, 2677, 2714, 2739, 2786, 2833, 2886, 2919, 2960, 2995, 3032, 3073, 3126, 3157, 3194, 3239, 3274, 3313, 3354, 3409, 3458, 3483, 3516, 3539, 3590, 3629, 3660, 3707, 3750, 3777, 3822, 3861, 3898, 3939, 3982], dtype=int32))
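One way to read that sparse output directly: according to the scikit-learn documentation, the forest's decision_path returns a tuple (indicator, n_nodes_ptr). indicator has one row per sample and one column per node, with the nodes of all trees concatenated; a non-zero entry means the sample passed through that node. n_nodes_ptr holds the column offsets, so columns n_nodes_ptr[t] to n_nodes_ptr[t + 1] - 1 belong to tree t (in the output above, the matrix has 3982 columns, matching the last offset). The minimal sketch below splits one sample's row back into per-tree lists of node ids. It assumes a synthetic dataset from make_classification as a hypothetical stand-in for the Titanic features, and nodes_per_tree is an illustrative helper, not part of scikit-learn.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical stand-in for the Titanic features: a small synthetic dataset
X, y = make_classification(n_samples=200, n_features=5, random_state=0)

estimator = RandomForestClassifier(random_state=0, n_estimators=3, max_depth=3)
estimator.fit(X, y)

# decision_path returns (sparse indicator matrix, per-tree column offsets)
indicator, n_nodes_ptr = estimator.decision_path(X[12:13])


def nodes_per_tree(indicator, n_nodes_ptr, sample=0):
    """Return, for each tree, the node ids (local to that tree) a sample visits."""
    indicator = indicator.tocsr()
    # In CSR format, indptr[sample]:indptr[sample + 1] slices this row's
    # non-zero column indices, i.e. every node the sample passed through
    row_start, row_end = indicator.indptr[sample], indicator.indptr[sample + 1]
    visited = indicator.indices[row_start:row_end]
    paths = []
    for t in range(len(n_nodes_ptr) - 1):
        start, end = n_nodes_ptr[t], n_nodes_ptr[t + 1]
        # Keep only tree t's columns and shift them back to local node ids;
        # node ids grow from root to leaf, so sorting gives the path order
        local = visited[(visited >= start) & (visited < end)] - start
        paths.append(np.sort(local))
    return paths


for t, path in enumerate(nodes_per_tree(indicator, n_nodes_ptr)):
    print(f"tree {t}: nodes {path.tolist()}")
```

Each printed list starts at node 0 (the root) and ends at the leaf that tree uses for its prediction; those node ids are the same ones shown in the Graphviz renderings of the individual trees.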

Apologies if I'm not providing enough detail - do let me know otherwise.

Thanks!

Note: Edited to simplify random forest (limit depth and n_trees)

asked Nov 09 '22 00:11 by jc023


1 Answer

If you would like to visualize the trees in the forest, you could try the answer provided here: https://stats.stackexchange.com/q/118016

Adapting to your problem:

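from sklearn import tree

...

# Write one Graphviz .dot file per tree in the fitted forest
for i_tree, tree_in_forest in enumerate(estimator.estimators_):
    with open('tree_' + str(i_tree) + '.dot', 'w') as my_file:
        # export_graphviz writes into my_file and returns None here,
        # so its result should not be assigned back to my_file
        tree.export_graphviz(tree_in_forest, out_file=my_file)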

This will create one file per tree, called tree_i.dot for i = 0 to n_estimators - 1 (three files with n_estimators=3 as in the question; the default forest had 10 trees before scikit-learn 0.22 and has 100 since). You can convert each one to a PDF from the terminal, for example:

$ dot -Tpdf tree_0.dot -o tree_0.pdf
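If installing Graphviz is a hurdle, a text-only rendering may be enough: since scikit-learn 0.21, sklearn.tree.export_text prints a fitted tree as indented split rules. The sketch below assumes the same hypothetical synthetic data as a stand-in for X_train/y_train, with made-up feature names; with the real data you would pass list(X_train.columns) instead.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

# Hypothetical stand-in data and feature names (replace with X_train, y_train)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]

estimator = RandomForestClassifier(random_state=0, n_estimators=3, max_depth=3)
estimator.fit(X, y)

# export_text renders each fitted tree as an indented "|--- rule" listing
reports = [export_text(t, feature_names=feature_names) for t in estimator.estimators_]
for i_tree, report in enumerate(reports):
    print(f"--- tree {i_tree} ---")
    print(report)
```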

There's probably a smarter way to do it; I'd be happy to learn it if anyone could help :)
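To tie the rendered trees back to the decision_path question, one option is to read a sample's path against each tree's split arrays (tree_.feature, tree_.threshold), in the spirit of scikit-learn's "Understanding the decision tree structure" example. This is a minimal sketch under assumptions: the synthetic data is a hypothetical stand-in for the Titanic features, and explain_path is an illustrative helper, not a scikit-learn function. The sample is cast to float32 because the trees compare features in float32 internally.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical stand-in data (replace with X_train, y_train and one test row)
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
estimator = RandomForestClassifier(random_state=0, n_estimators=3, max_depth=3)
estimator.fit(X, y)

sample = X[12:13].astype(np.float32)  # one row, kept 2-D, in the trees' dtype


def explain_path(tree_model, sample):
    """List the split tests one sample passes through in a single fitted tree."""
    structure = tree_model.tree_
    node_ids = tree_model.decision_path(sample).indices
    leaf_id = tree_model.apply(sample)[0]
    rules = []
    for node_id in sorted(node_ids):  # ids grow from root to leaf
        if node_id == leaf_id:
            continue  # the leaf holds the prediction, not a split test
        feature = structure.feature[node_id]
        threshold = structure.threshold[node_id]
        value = sample[0, feature]
        # scikit-learn sends samples with value <= threshold to the left child
        sign = "<=" if value <= threshold else ">"
        rules.append(f"node {node_id}: X[{feature}] = {value:.3f} {sign} {threshold:.3f}")
    return rules, leaf_id


for i_tree, tree_in_forest in enumerate(estimator.estimators_):
    rules, leaf_id = explain_path(tree_in_forest, sample)
    print(f"tree {i_tree} (leaf {leaf_id}):")
    for rule in rules:
        print("   ", rule)
```

The node numbers printed here match the node ids in the .dot files exported by the loop above, so each rule can be located in the corresponding PDF.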

answered Nov 14 '22 21:11 by Luciana Abud