 

Prune unnecessary leaves in sklearn DecisionTreeClassifier

I use sklearn.tree.DecisionTreeClassifier to build a decision tree. With the optimal parameter settings, I get a tree that has unnecessary leaves (see the example picture below - I do not need probabilities, so the leaf nodes marked in red are an unnecessary split).

[Image: the example decision tree, with the redundant sibling leaves marked in red]

Is there any third-party library for pruning these unnecessary nodes? Or a code snippet? I could write one, but I can't really imagine that I am the first person with this problem...

Code to replicate:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets
iris = datasets.load_iris()
X = iris.data
y = iris.target
mdl = DecisionTreeClassifier(max_leaf_nodes=8)
mdl.fit(X,y)

PS: I have tried multiple keyword searches and am kind of surprised to find nothing - is there really no general post-pruning in sklearn?

PPS: In response to the possible duplicate: While the suggested question might help me when coding the pruning algorithm myself, it answers a different question - I want to get rid of leaves that do not change the final decision, while the other question wants a minimum threshold for splitting nodes.

PPPS: The tree shown is an example to illustrate my problem. I am aware that the parameter settings used to create it are suboptimal. I am not asking about optimizing this specific tree; I need post-pruning to get rid of leaves that might be helpful if one needs class probabilities, but are not helpful if one is only interested in the most likely class.

asked Jul 18 '18 by Thomas


People also ask

Does pruning reduce overfitting?

Pruning reduces the complexity of the final classifier, and hence improves predictive accuracy by reducing overfitting.

What is weakest link pruning?

Essentially, pruning recursively finds the node with the “weakest link.” The weakest link is characterized by an effective alpha, where the nodes with the smallest effective alpha are pruned first. Mathematically, the cost-complexity measure for a tree T is given by R_alpha(T) = R(T) + alpha * |T|, where R(T) is the total training error of the leaf nodes and |T| is the number of leaves.
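
Since this question was asked, scikit-learn has added built-in minimal cost-complexity pruning (version 0.22) via the ccp_alpha parameter. A minimal sketch on the iris data from the question:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Compute the effective alphas along the pruning path on the training data
clf = DecisionTreeClassifier(random_state=0)
path = clf.cost_complexity_pruning_path(X, y)

# Refit with each alpha; larger alphas prune more aggressively
for ccp_alpha in path.ccp_alphas:
    pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha).fit(X, y)
    print("alpha={:.4f} -> {} leaves".format(ccp_alpha, pruned.get_n_leaves()))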

How do you do post-pruning on a decision tree?

The post-pruning approach eliminates branches from a “completely grown” tree. A tree node is pruned by eliminating its branches. The cost-complexity pruning algorithm is an instance of the post-pruning approach. The pruned node turns into a leaf and is labeled with the most common class among its former branches.


2 Answers

Using ncfirth's link, I was able to modify the code there so that it fits my problem:

from sklearn.tree._tree import TREE_LEAF

def is_leaf(inner_tree, index):
    # Check whether node is leaf node
    return (inner_tree.children_left[index] == TREE_LEAF and 
            inner_tree.children_right[index] == TREE_LEAF)

def prune_index(inner_tree, decisions, index=0):
    # Start pruning from the bottom - if we start from the top, we might miss
    # nodes that become leaves during pruning.
    # Do not use this directly - use prune_duplicate_leaves instead.
    if not is_leaf(inner_tree, inner_tree.children_left[index]):
        prune_index(inner_tree, decisions, inner_tree.children_left[index])
    if not is_leaf(inner_tree, inner_tree.children_right[index]):
        prune_index(inner_tree, decisions, inner_tree.children_right[index])

    # Prune children if both children are leaves now and make the same decision:     
    if (is_leaf(inner_tree, inner_tree.children_left[index]) and
        is_leaf(inner_tree, inner_tree.children_right[index]) and
        (decisions[index] == decisions[inner_tree.children_left[index]]) and 
        (decisions[index] == decisions[inner_tree.children_right[index]])):
        # turn node into a leaf by "unlinking" its children
        inner_tree.children_left[index] = TREE_LEAF
        inner_tree.children_right[index] = TREE_LEAF
        # print("Pruned {}".format(index))

def prune_duplicate_leaves(mdl):
    # Remove a split when both of its child leaves predict the same class
    decisions = mdl.tree_.value.argmax(axis=2).flatten().tolist() # Decision for each node
    prune_index(mdl.tree_, decisions)

Using this on a DecisionTreeClassifier clf:

prune_duplicate_leaves(clf)

Edit: Fixed a bug for more complex trees
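
A minimal end-to-end sketch, assuming the functions above are defined and using the iris setup from the question; the prediction check confirms that only redundant splits were removed:

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()
clf = DecisionTreeClassifier(max_leaf_nodes=8)
clf.fit(iris.data, iris.target)

preds_before = clf.predict(iris.data)
prune_duplicate_leaves(clf)
preds_after = clf.predict(iris.data)

# Pruned splits never change the predicted class, so outputs are identical
print(np.array_equal(preds_before, preds_after))  # True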

answered Oct 27 '22 by Thomas


DecisionTreeClassifier(max_leaf_nodes=8) specifies at most 8 leaves, so unless the tree builder has another reason to stop, it will hit that maximum.

In the example shown, 5 of the 8 leaves contain very few samples (<=3) compared to the other 3 leaves (>50), a possible sign of over-fitting. Instead of pruning the tree after training, one can specify either min_samples_leaf or min_samples_split to better guide the training, which will likely get rid of the problematic leaves. For instance, use the value 0.05 to require at least 5% of the samples in each leaf.
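
A minimal sketch of that suggestion; note that float values of min_samples_leaf are interpreted as a fraction of the training samples:

from sklearn.tree import DecisionTreeClassifier
from sklearn import datasets

iris = datasets.load_iris()
# Require each leaf to contain at least 5% of the training samples
mdl = DecisionTreeClassifier(max_leaf_nodes=8, min_samples_leaf=0.05)
mdl.fit(iris.data, iris.target)
print(mdl.get_n_leaves())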

answered Oct 27 '22 by Jon Nordby