I am having some difficulty understanding how the trees are structured in R's `gbm` gradient boosted machine package. Specifically, looking at the output of `pretty.gbm.tree`, which features do the indices in `SplitVar` point to?

I trained a GBM on a dataset; here is the top ~quarter of one of my trees, the result of a call to `pretty.gbm.tree`:
```
SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight Prediction
0 9 6.250000e+01 1 2 21 0.6634681 5981 0.005000061
1 -1 1.895699e-12 -1 -1 -1 0.0000000 3013 0.018956988
2 31 4.462500e+02 3 4 20 1.0083722 2968 -0.009168477
3 -1 1.388483e-22 -1 -1 -1 0.0000000 1430 0.013884830
4 38 5.500000e+00 5 18 19 1.5748155 1538 -0.030602956
5 24 7.530000e+03 6 13 17 2.8329899 361 -0.078738904
6 41 2.750000e+01 7 11 12 2.2499063 334 -0.064752766
7 28 -3.155000e+02 8 9 10 1.5516610 57 -0.243675567
8 -1 -3.379312e-11 -1 -1 -1 0.0000000 45 -0.337931219
9 -1 1.922333e-10 -1 -1 -1 0.0000000 12 0.109783128
```
It looks to me like the indices are 0-based, judging from how `LeftNode`, `RightNode`, and `MissingNode` point to different rows. But when I test this by taking data samples and following them down the tree to their predictions, I get the correct answer only when I treat `SplitVar` as using 1-based indexing.
However, one of the many trees I built has a zero in the `SplitVar` column! Here is that tree:
```
SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight Prediction
0 4 1.462500e+02 1 2 21 0.41887 5981 0.0021651262
1 -1 4.117688e-22 -1 -1 -1 0.00000 512 0.0411768781
2 4 1.472500e+02 3 4 20 1.05222 5469 -0.0014870985
3 -1 -2.062798e-11 -1 -1 -1 0.00000 23 -0.2062797579
4 0 4.750000e+00 5 6 19 0.65424 5446 -0.0006222011
5 -1 3.564879e-23 -1 -1 -1 0.00000 4897 0.0035648788
6 28 -3.195000e+02 7 11 18 1.39452 549 -0.0379703437
```
What is the correct way to view the indexing used by gbm's trees?
The first column printed by `pretty.gbm.tree` holds the `row.names` assigned in the script `pretty.gbm.tree.R`. In that script, the row names are assigned as `row.names(temp) <- 0:(nrow(temp)-1)`, where `temp` is the tree information stored as a `data.frame`. The right way to interpret these row names is as node ids, with the root node assigned the value 0.
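To see this in your own session, here is a minimal sketch; the object name `gbm_fit` is just a placeholder for whatever you called your fitted model:

```r
library(gbm)

# Extract the first tree of the fitted model as a data.frame.
tree_df <- pretty.gbm.tree(gbm_fit, i.tree = 1)

# The row names are the node ids: "0" (the root), "1", "2", ...
row.names(tree_df)
```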
In your example:

```
Id SplitVar SplitCodePred LeftNode RightNode MissingNode ErrorReduction Weight Prediction
0 9 6.250000e+01 1 2 21 0.6634681 5981 0.005000061
```
means that the root node (indicated by row number 0) is split by split variable 9 (the numbering of split variables starts from 0, so this is the 10th column of the training set `x`).
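As a quick illustration (reusing the placeholder names `gbm_fit` and `tree_df` from the sketch above), you can map the 0-based `SplitVar` indices back to variable names by adding 1:

```r
# Convert the 0-based SplitVar column to 1-based R indices.
idx <- tree_df$SplitVar + 1

# Leaf rows have SplitVar == -1 and no split variable; indexing with NA gives NA.
idx[tree_df$SplitVar == -1] <- NA

# Names of the variables used at each internal node.
split_var_names <- gbm_fit$var.names[idx]
```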
The `SplitCodePred` of `6.250000e+01` (that is, 62.5) means that all points with a value below 62.5 in that variable went to `LeftNode` 1, the remaining non-missing points went to `RightNode` 2, and all points with a missing value in this variable were assigned to `MissingNode` 21. The `ErrorReduction` due to this split was 0.6634, and there were 5981 observations (`Weight`) in the root node. The `Prediction` of 0.005 is the value assigned to all points at this node before it was split. For terminal nodes (leaves), marked by -1 in `SplitVar`, `LeftNode`, `RightNode`, and `MissingNode`, `Prediction` is the value predicted for all points belonging to that leaf, already adjusted (multiplied) by the shrinkage.
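Putting the pieces together, here is a rough sketch (not the package's own prediction code) of how a single observation could be followed down one of these trees. It assumes continuous split variables only and reuses the placeholder objects `gbm_fit` and `tree_df` from above; `follow_tree` and `x_row` are just illustrative names:

```r
# Follow one observation down the tree and return the leaf Prediction.
# x_row: a one-row data.frame with the same columns as the training data.
follow_tree <- function(tree_df, gbm_fit, x_row) {
  node <- 0                                  # node ids start at 0 (the root)
  repeat {
    info <- tree_df[as.character(node), ]    # row names are the node ids
    if (info$SplitVar == -1) {
      return(info$Prediction)                # leaf: prediction for this node
    }
    # SplitVar is 0-based, so add 1 to index into the variable names.
    x_val <- x_row[[gbm_fit$var.names[info$SplitVar + 1]]]
    node <- if (is.na(x_val)) {
      info$MissingNode                       # missing values take the missing branch
    } else if (x_val < info$SplitCodePred) { # continuous split: smaller values go left
      info$LeftNode
    } else {
      info$RightNode
    }
  }
}
```

Categorical split variables are handled differently in gbm (the split code then refers to a stored set of category assignments rather than a numeric threshold), so this sketch only covers the continuous case shown in your output.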
To understand the tree structure, it is important to note that the splitting happens in a depth-first fashion: when the root node (node id 0) is split into its left and right nodes, the left side is processed until no further splits are possible before the algorithm returns and labels the right node. In both of the trees in your example, `RightNode` gets the value 2 because in both cases the `LeftNode` turns out to be a leaf.