
decision tree repeating class names

I have a very simple sample of data and labels. The problem is that the generated decision tree (PDF) repeats the class name:

from sklearn import tree
from io import StringIO  # sklearn.externals.six was removed in newer scikit-learn releases
import pydotplus

features_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

clf = tree.DecisionTreeClassifier()
clf.fit(features, labels)

dot_data = StringIO()
tree.export_graphviz(clf, out_file=dot_data, 
                         feature_names=features_names,  
                         class_names=labels,  
                         filled=True, rounded=True,  
                         special_characters=True,
                         impurity=False)

graph = pydotplus.graph_from_dot_data(dot_data.getvalue()) 
graph.write_pdf("apples_oranges.pdf")

The resulting PDF (image not included here) shows the problem: the class is apple for both possibilities. What am I doing wrong?

From the docs, for the class_names parameter:

list of strings, bool or None, optional (default=None)
Names of each of the target classes in ascending numerical order. Only relevant for classification and not supported for multi-output. If True, shows a symbolic representation of the class name.

"...ascending numerical order" doesn't make much sense to me. If I change the kwarg to:

class_names=sorted(labels)

The result is the same (obvious in this case).

Hula Hula asked Dec 11 '25 18:12


1 Answer

The class names are literally just that: the names of the classes, not the label for each example.

Here one class is 'apple' and the other is 'orange', so you just need to pass in ['apple', 'orange'].
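As a minimal sketch of that fix: export_graphviz looks up class_names by class index, so passing the four per-sample labels means index 0 and index 1 are both 'apple'. The version below reuses the question's data and takes the names from the fitted classifier's classes_ attribute, which holds the unique labels in sorted order. Using out_file=None (so the DOT source comes back as a string) and random_state=0 are my choices for a self-contained example, not part of the original code.

```python
from sklearn import tree

feature_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(features, labels)

# One name per class (not per sample), in the order the classifier uses.
class_names = clf.classes_.tolist()

# out_file=None makes export_graphviz return the DOT source as a string.
dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=feature_names,
                                class_names=class_names,
                                filled=True, rounded=True,
                                impurity=False)
print(dot_data)
```

The resulting string can be handed to pydotplus.graph_from_dot_data exactly as in the question to write the PDF.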

Regarding order: to keep it consistent, you could use a LabelEncoder to convert your target to integers with int_labels = labelEncoder.fit_transform(labels), fit your decision tree on int_labels, and then pass the labelEncoder.classes_ attribute as class_names to export_graphviz.
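The LabelEncoder approach could look roughly like this. It is a sketch on the question's data; the variable names, random_state=0, and returning the DOT source as a string via out_file=None are assumptions for illustration.

```python
from sklearn import tree
from sklearn.preprocessing import LabelEncoder

feature_names = ['weight', 'texture']
features = [[140, 1], [130, 1], [150, 0], [110, 0]]
labels = ['apple', 'apple', 'orange', 'orange']

# Map each string label to an integer; classes_ keeps the matching names
# in ascending order of those integers.
label_encoder = LabelEncoder()
int_labels = label_encoder.fit_transform(labels)

clf = tree.DecisionTreeClassifier(random_state=0)
clf.fit(features, int_labels)

# class_names[i] is the name of integer class i.
class_names = label_encoder.classes_.tolist()

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=feature_names,
                                class_names=class_names,
                                filled=True, rounded=True,
                                impurity=False)
```

This makes the "ascending numerical order" wording in the docs concrete: the names are listed in the order of the integer codes the tree was trained on.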

Ken Syme answered Dec 14 '25 07:12


