I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below - I do not need probabilities, so the leaf nodes marked in red are an unnecessary split).
Is there any third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...
Code to replicate:
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)
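To reproduce the picture and spot the redundant leaves yourself, here is a minimal sketch continuing from the snippet above, using plot_tree (available in scikit-learn 0.21+; in older versions export_graphviz works similarly):

import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# Visualize the fitted tree; sibling leaves that share the same
# majority class are the redundant splits described above.
plt.figure(figsize=(12, 6))
plot_tree(mdl, feature_names=iris.feature_names, class_names=list(iris.target_names), filled=True)
plt.show()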
PS: I have tried multiple keyword searches and am kind of surprised to find nothing - is there really no post-pruning in general in sklearn?
PPS: In response to the possible duplicate: While the suggested question might help me when coding the pruning algorithm myself, it answers a different question - I want to get rid of leaves that do not change the final decision, while the other question wants a minimum threshold for splitting nodes.
PPPS: The tree shown is an example to show my problem. I am aware of the fact that the parameter settings to create the tree are suboptimal. I am not asking about optimizing this specific tree, I need to do post-pruning to get rid of leaves that might be helpful if one needs class probabilities, but are not helpful if one is only interested in the most likely class.
Pruning reduces the complexity of the final classifier and hence improves predictive accuracy by reducing overfitting.
Essentially, pruning recursively finds the node with the “weakest link.” The weakest link is characterized by an effective alpha, and the nodes with the smallest effective alpha are pruned first. Mathematically, the cost complexity measure for a tree T is given by R_alpha(T) = R(T) + alpha * |T|, where R(T) is the total training error of the leaf nodes and |T| is the number of leaf nodes.
The post-pruning approach removes branches from a “fully grown” tree: a node is pruned by removing its branches. Cost complexity pruning is an instance of this post-pruning approach. The pruned node becomes a leaf and is labeled with the most common class among its former branches.
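If your scikit-learn version is recent enough (0.22+), this cost complexity pruning is available directly through the ccp_alpha parameter and cost_complexity_pruning_path; a minimal sketch on the iris data from the question:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

X, y = datasets.load_iris(return_X_y=True)

# Compute the effective alphas along the pruning path of the trained tree
clf = DecisionTreeClassifier(max_leaf_nodes=8)
path = clf.cost_complexity_pruning_path(X, y)
print(path.ccp_alphas)  # candidate alphas, weakest links (smallest) first

# Refit with ccp_alpha set: nodes whose effective alpha does not exceed it are pruned
pruned = DecisionTreeClassifier(max_leaf_nodes=8, ccp_alpha=path.ccp_alphas[1]).fit(X, y)
print(pruned.get_n_leaves())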
Using ncfirth's link, I was able to modify the code there so that it fits to my problem:
from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        ##print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove leaves if both children make the same decision as their parent
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist()  # Decision for each node
    prune_index(mdl.tree_, decisions)
Using this on a DecisionTreeClassifier clf:
prune_duplicate_leaves(clf)
Edit: Fixed a bug for more complex trees
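A quick sanity check (a sketch on the iris data from the question) that this pruning only merges leaves which refined class probabilities, so the predicted classes stay the same:

import numpy as np
from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)
clf = DecisionTreeClassifier(max_leaf_nodes=8).fit(X, y)
before = clf.predict(X)

prune_duplicate_leaves(clf)

# Only leaves whose majority class matches their parent's were merged,
# so the predicted classes do not change.
print(np.array_equal(before, clf.predict(X)))  # expected: True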
DecisionTreeClassifier(max_leaf_nodes=8)
specifies at most 8 leaves, so unless the tree builder has another reason to stop, it will hit that maximum.
In the example shown, 5 of the 8 leaves contain very few samples (<= 3) compared to the other 3 leaves (> 50), a possible sign of overfitting.
Instead of pruning the tree after training, one can specify either min_samples_leaf or min_samples_split to better guide the training, which will likely get rid of the problematic leaves. For instance, use the value 0.05 to require at least 5% of the samples per leaf.
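A sketch of this alternative on the iris data from the question (fractional min_samples_leaf requires scikit-learn 0.18+):

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

X, y = datasets.load_iris(return_X_y=True)

# Require each leaf to hold at least 5% of the training samples,
# which prevents the tiny probability-only leaves from being created.
mdl = DecisionTreeClassifier(max_leaf_nodes=8, min_samples_leaf=0.05)
mdl.fit(X, y)
print(mdl.get_n_leaves())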