Possible to modify/prune learned trees in scikit-learn?

It is possible to access tree parameters in sklearn with

tree.tree_.children_left
tree.tree_.children_right
tree.tree_.threshold
tree.tree_.feature

and so on

However, trying to write to these attributes raises a "not writable" exception.

Is there any way to modify the learned tree, or to bypass the AttributeError: not writable?
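
A minimal sketch that reproduces the exception (the regressor and dataset here are just placeholders, not from the question):

from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

bunch = load_diabetes()
dtr = DecisionTreeRegressor(max_depth=4).fit(bunch.data, bunch.target)

# Reading the learned structure works:
print(dtr.tree_.children_left)

# Reassigning the attribute fails:
# AttributeError: attribute 'children_left' of
# 'sklearn.tree._tree.Tree' objects is not writable
dtr.tree_.children_left = dtr.tree_.children_left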

asked Aug 17 '16 by Eric Thibodeau Laufer

People also ask

Does Sklearn decision tree prune?

We will use DecisionTreeClassifier from sklearn.tree for this purpose. By default, the decision tree doesn't perform any pruning and allows the tree to grow as much as it can. We get an accuracy score of 0.95 on the training set but only 0.63 on the test set.
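
As a rough sketch of that train/test gap, and of the built-in post-pruning available via the ccp_alpha parameter in scikit-learn 0.22+ (the dataset and alpha value below are arbitrary illustrative choices, not from the original answer):

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Unpruned by default: the tree grows until every leaf is pure,
# so training accuracy is near 1.0 while test accuracy lags behind
unpruned = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)
print(unpruned.score(X_train, y_train), unpruned.score(X_test, y_test))

# ccp_alpha > 0 enables minimal cost-complexity post-pruning
pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=0).fit(X_train, y_train)
print(pruned.score(X_train, y_train), pruned.score(X_test, y_test))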

Does pruning increase accuracy?

Pruning can not only significantly reduce the size of a decision tree but also improve its classification accuracy on unseen objects.

Is tree pruning useful in decision tree induction?

Pruning reduces the size of decision trees by removing parts of the tree that do not provide power to classify instances. Decision trees are among the machine learning algorithms most susceptible to overfitting, and effective pruning can reduce this likelihood.

Why do we need to perform pruning for the decision tree model?

The decision tree's overfitting problem is caused by several factors, such as branches that are impacted by noise and outliers in the data. Pruning is a critical step in constructing tree-based machine learning models and helps overcome these issues.


1 Answer

These attributes are arrays of int that cannot be reassigned, but you can still modify their elements in place. Note that this will not shrink the stored data: the cut-off nodes remain allocated in the arrays.

children_left : array of int, shape [node_count]
    children_left[i] holds the node id of the left child of node i.
    For leaves, children_left[i] == TREE_LEAF. Otherwise,
    children_left[i] > i. This child handles the case where
    X[:, feature[i]] <= threshold[i].

children_right : array of int, shape [node_count]
    children_right[i] holds the node id of the right child of node i.
    For leaves, children_right[i] == TREE_LEAF. Otherwise,
    children_right[i] > i. This child handles the case where
    X[:, feature[i]] > threshold[i].

feature : array of int, shape [node_count]
    feature[i] holds the feature to split on, for the internal node i.

threshold : array of double, shape [node_count]
    threshold[i] holds the threshold for the internal node i.
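
To make those semantics concrete, here is a small illustrative walk over a fitted tree using exactly these arrays (this sketch is not part of the original answer; -1 is sklearn's TREE_LEAF sentinel):

from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor

TREE_LEAF = -1  # sentinel meaning "no child"

bunch = load_diabetes()
dtr = DecisionTreeRegressor(max_depth=2).fit(bunch.data, bunch.target)
t = dtr.tree_

def walk(i=0, depth=0):
    indent = '    ' * depth
    if t.children_left[i] == TREE_LEAF:
        # value has shape [node_count, n_outputs, max_n_classes];
        # for a regressor this is the mean target in the leaf
        print(indent + 'leaf %d: value = %.2f' % (i, t.value[i][0][0]))
    else:
        print(indent + 'node %d: X[:, %d] <= %.3f' % (i, t.feature[i], t.threshold[i]))
        walk(t.children_left[i], depth + 1)   # samples with X[:, feature[i]] <= threshold[i]
        walk(t.children_right[i], depth + 1)  # samples with X[:, feature[i]] > threshold[i]

walk()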

To prune a DecisionTree by the number of observations in its nodes, I use the function below. You need to know that the TREE_LEAF constant is equal to -1.

def prune(decisiontree, min_samples_leaf=1):
    if decisiontree.min_samples_leaf >= min_samples_leaf:
        raise Exception('Tree already more pruned')
    else:
        decisiontree.min_samples_leaf = min_samples_leaf
        tree = decisiontree.tree_
        for i in range(tree.node_count):
            n_samples = tree.n_node_samples[i]
            if n_samples <= min_samples_leaf:
                # Cut off both children: setting them to TREE_LEAF (-1)
                # turns node i into a leaf
                tree.children_left[i] = -1
                tree.children_right[i] = -1

Here is an example which produces Graphviz output before and after pruning:

from sklearn.tree import DecisionTreeRegressor as DTR
from sklearn.datasets import load_diabetes
from sklearn.tree import export_graphviz as export

bunch = load_diabetes()
data = bunch.data
target = bunch.target

dtr = DTR(max_depth=4)
dtr.fit(data, target)

export(decision_tree=dtr.tree_, out_file='before.dot')
prune(dtr, min_samples_leaf=100)
export(decision_tree=dtr.tree_, out_file='after.dot')
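
If the Graphviz command-line tools are installed, the .dot files can be rendered to images with, for example, dot -Tpng before.dot -o before.png; comparing the two images should show the cut-off subtrees rendered as leaves in after.dot.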
answered Oct 13 '22 by S. Lundy