
Get feature and class names into decision tree using export_graphviz

Good Afternoon,

I am working on a decision tree classifier and am having trouble visualizing it. I can output the decision tree; however, I cannot get my feature or class names/labels into it. My data is in a pandas dataframe, which I then convert to a numpy array and pass to the classifier. I've tried a few things, but the export errors out when I try to specify class names. Any help would be appreciated. Code is below.

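from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Feature columns start at position 14; .iloc replaces the removed .ix indexer
all_inputs=df.iloc[:,14:].values
all_classes=df['wic'].values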

(training_inputs,
 testing_inputs,
 training_classes,
 testing_classes) = train_test_split(all_inputs, all_classes,train_size=0.75, random_state=1)

decision_tree_classifier=DecisionTreeClassifier()
decision_tree_classifier.fit(training_inputs,training_classes)

export_graphviz(decision_tree_classifier, out_file="mytree.dot",  
                     feature_names=??,  
                     class_names=??)  

Like I said, it runs fine and outputs a decision tree viz if I take out the feature_names and class_names parameters. I'd like to include them in the output if possible, though, and have hit a wall...

Any help would be greatly appreciated!

Thanks,

Scott

sokeefe1014 asked Sep 13 '16 17:09

1 Answer

The class names are stored in decision_tree_classifier.classes_, i.e. the classes_ attribute of your fitted DecisionTreeClassifier instance. The feature names are the columns of your input dataframe that you passed to the classifier, in the same order. In your case:

class_names = decision_tree_classifier.classes_
feature_names = df.columns[14:]
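To show how the two pieces fit together, here is a minimal end-to-end sketch. The toy dataframe and its column names (income, household_size) are hypothetical stand-ins for the asker's data, and the slice starts at column 1 rather than 14. Converting both sequences to plain Python lists of strings is an assumption made for safety: export_graphviz builds text labels from class_names, so numeric class labels may need to be turned into strings first.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

# Hypothetical stand-in for the asker's dataframe: a 'wic' label plus two features.
df = pd.DataFrame({
    "wic": [0, 1, 0, 1, 0, 1, 0, 1],
    "income": [10, 50, 12, 48, 11, 52, 9, 47],
    "household_size": [1, 4, 2, 5, 1, 4, 2, 5],
})

# The asker's data would use df.columns[14:] here instead.
feature_cols = df.columns[1:]
all_inputs = df[feature_cols].values
all_classes = df["wic"].values

(training_inputs,
 testing_inputs,
 training_classes,
 testing_classes) = train_test_split(all_inputs, all_classes,
                                     train_size=0.75, random_state=1)

decision_tree_classifier = DecisionTreeClassifier(random_state=1)
decision_tree_classifier.fit(training_inputs, training_classes)

# Plain lists of strings: feature names in column order,
# class names in the sorted order stored in classes_.
feature_names = list(feature_cols)
class_names = [str(c) for c in decision_tree_classifier.classes_]

# out_file=None returns the DOT source as a string instead of writing a file;
# pass out_file="mytree.dot" to write to disk as in the question.
dot_source = export_graphviz(decision_tree_classifier, out_file=None,
                             feature_names=feature_names,
                             class_names=class_names)
print(dot_source)
```

The order matters in both lists: feature_names must match the column order of the array passed to fit, and class_names must follow the ascending order of classes_.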
maxymoo answered Oct 10 '22 09:10
