confused by apply function of GradientBoostingClassifier

For the apply function itself, you can refer to the documentation here.

My confusion comes more from this sample. I have added some print statements to the code snippet below to output more debug information:

from sklearn.ensemble import GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import OneHotEncoder

# X_train, y_train, X_train_lr, y_train_lr and n_estimator are defined
# earlier in the linked sample
grd = GradientBoostingClassifier(n_estimators=n_estimator)
grd_enc = OneHotEncoder()
grd_lm = LogisticRegression()
grd.fit(X_train, y_train)
test_var = grd.apply(X_train)[:, :, 0]
print("test_var.shape", test_var.shape)
print("test_var", test_var)
grd_enc.fit(grd.apply(X_train)[:, :, 0])
grd_lm.fit(grd_enc.transform(grd.apply(X_train_lr)[:, :, 0]), y_train_lr)

The output is shown below. What do the numbers like 6., 3. and 10. mean, and how are they related to the final classification result?

test_var.shape (20000, 10)
test_var [[  6.   6.   6. ...,  10.  10.  10.]
 [ 10.  10.  10. ...,   3.   3.   3.]
 [  6.   6.   6. ...,  11.  10.  10.]
 ..., 
 [  6.   6.   6. ...,  10.  10.  10.]
 [  6.   6.   6. ...,  11.  10.  10.]
 [  6.   6.   6. ...,  11.  10.  10.]]
asked May 01 '18 by Lin Ma

1 Answer

To understand gradient boosting, you first need to understand individual trees. I will show a small example.

Here is the setup: a small GB model trained on the Iris dataset to predict whether a flower belongs to class 2.

# import the most common dataset
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier
X, y = load_iris(return_X_y=True)
# there are 150 observations and 4 features
print(X.shape) # (150, 4)
# let's build a small model = 5 trees with depth no more than 2
model = GradientBoostingClassifier(n_estimators=5, max_depth=2, learning_rate=1.0)
model.fit(X, y==2) # predict 2nd class vs rest, for simplicity
# we can access individual trees
trees = model.estimators_.ravel()
print(len(trees)) # 5
# there are 150 observations, each is encoded by 5 trees, each tree has 1 output
applied = model.apply(X) 
print(applied.shape) # (150, 5, 1)
print(applied[0].T) # [[2. 2. 2. 5. 2.]] - a single row of the apply() result
print(X[0]) # [5.1 3.5 1.4 0.2] - the observation corresponding to that row
print(trees[0].apply(X[[0]])) # [2] - the result of applying the 0th tree to the sample
print(trees[3].apply(X[[0]])) # [5] - the result of applying the 3rd tree to the sample

You can see that each number in the sequence [2. 2. 2. 5. 2.] produced by model.apply() corresponds to an output of a single tree. But what do these numbers mean?
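
As a quick sanity check, here is a small sketch (reusing model, trees, applied and X from above) that compares the per-tree apply() outputs with the columns of the ensemble-level result:

import numpy as np
# apply each tree individually and compare with the ensemble-level apply()
per_tree = np.array([tree.apply(X[[0]])[0] for tree in trees])
print(per_tree)                                   # [2. 2. 2. 5. 2.]
print(np.allclose(per_tree, applied[0].ravel()))  # True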

We can easily analyze decision trees by visual examination. Here is a function to plot one:

# a function to draw a tree. You need pydotplus and graphviz installed 
# sudo apt-get install graphviz
# pip install pydotplus

from io import StringIO  # sklearn.externals.six was removed in newer scikit-learn versions
from IPython.display import Image  
from sklearn.tree import export_graphviz
import pydotplus
def plot_tree(clf):
    dot_data = StringIO()
    export_graphviz(clf, out_file=dot_data, node_ids=True,
                    filled=True, rounded=True, 
                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
    return Image(graph.create_png())

# now we can plot the first tree
plot_tree(trees[0])

(tree plot of trees[0]; nodes are numbered 0 to 6)

You can see that each node has a number (from 0 to 6). If we push our single example into this tree, it will first go to node #1 (because the feature x3 has value 0.2 < 1.75), and then to node #2 (because the feature x2 has value 1.4 < 4.95).

In the same way, we can analyze tree 3, which produced the output 5:

plot_tree(trees[3])

(tree plot of trees[3])

Here our observation goes first to node #4 and then to node #5, because x1 = 3.5 > 2.25 and x2 = 1.4 < 4.85. Thus, it ends up in node #5. Both traversals can also be checked programmatically, as in the sketch below.
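
Here is that sketch (reusing trees and X from above): decision_path returns an indicator of every node the sample passes through, so its indices should match the node ids we just traced by eye.

# decision_path gives a sparse indicator matrix of the visited nodes
print(trees[0].decision_path(X[[0]]).indices)  # [0 1 2] - root, node #1, node #2
print(trees[3].decision_path(X[[0]]).indices)  # [0 4 5] - root, node #4, node #5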

It's that simple! Each number produced by apply() is the ordinal number of the node of the corresponding tree in which the sample ends up.
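
These nodes are always leaves. As a small sketch, we can verify that through the tree_ attribute, where a leaf is marked by children_left == -1:

# every id returned by apply() should be a leaf (a node with no children)
for tree, leaf in zip(trees, applied[0].ravel().astype(int)):
    assert tree.tree_.children_left[leaf] == -1  # -1 means "no left child"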

These numbers are related to the final classification result through the values of the leaves in the corresponding trees. In binary classification, the leaf values simply add up: if the sum is positive, the 'positive' class wins, otherwise the 'negative' class. In multiclass classification, the values add up per class, and the class with the largest total wins.

In our case, the first tree (with its node #2) gives the value -1.454, the other trees also contribute some values, and their total sum is -4.84. It is negative, so our example does not belong to class 2.

values = [trees[i].tree_.value[int(leaf)][0,0] for i, leaf in enumerate(applied[0].ravel())]
print(values) # [-1.454, -1.05, -0.74, -1.016, -0.58] - the values of nodes [2,2,2,5,2] in the corresponding trees
print(sum(values)) # -4.84 - sum of these values is negative -> this is not class 2
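
For completeness, a small sketch assuming the standard gradient-boosting decomposition: model.decision_function returns the full raw score, i.e. this sum of leaf values (scaled by learning_rate, 1.0 here) plus a constant initial prediction, so its sign agrees with our manual sum.

raw = model.decision_function(X[[0]])[0]
print(raw)                    # still negative -> class 2 loses
print(raw - sum(values))      # the constant offset is the initial (prior) prediction
print(model.predict(X[[0]]))  # [False] - consistent with the negative raw score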
answered Sep 18 '22 by David Dale