
How to use the actual feature names instead of "X" in scikit-learn DecisionTreeRegressor?

I suppose this is possible, since the docstring of the fit function says:

X : array-like, shape = [n_samples, n_features]

Now I have a tree in which the features appear only as X[0], X[1], and so on.

I can certainly generate a string representation of the decision tree and then replace X[] with the actual feature names, but I wonder whether the fit function could take the feature names directly as part of its input. I tried the following formats for each sample:

  • [1, 2, "feature_1", "feature_2"]

  • [[1, 2], ["feature_1", "feature_2"]]

but neither worked. What does that shape mean? Could you please give me an example?

asked Jan 29 '14 01:01 by shuaiyuancn


1 Answer

The fit function itself doesn't support anything like that. However, you can draw the decision tree, including feature labels, with the export_graphviz function from sklearn.tree. (Isn't this how you generated the tree above?) Essentially, you'd do something like this:

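```python
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
t = tree.DecisionTreeClassifier()
fitted_tree = t.fit(iris.data, iris.target)  # fit() returns the fitted estimator
# Writes the tree to filename.dot, labelling splits with the real feature names.
# Given a file name, export_graphviz opens and closes the file itself.
tree.export_graphviz(fitted_tree, out_file='filename.dot',
                     feature_names=iris.feature_names)
```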
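If you only need a text dump rather than an image, newer scikit-learn releases (0.21 and later, as far as I know) also provide sklearn.tree.export_text, which accepts the same feature_names argument and needs no Graphviz install. A minimal sketch, using DecisionTreeRegressor as in the question; the bundled diabetes dataset and max_depth=2 are my own choices for illustration, not part of the original answer:

```python
from sklearn.datasets import load_diabetes
from sklearn.tree import DecisionTreeRegressor, export_text

# Bundled regression dataset; feature_names is a plain list of column names.
diabetes = load_diabetes()

# A shallow tree keeps the printed report short (max_depth=2 is arbitrary).
reg = DecisionTreeRegressor(max_depth=2, random_state=0)
reg.fit(diabetes.data, diabetes.target)

# The names must be given in the same order as the columns of X.
report = export_text(reg, feature_names=list(diabetes.feature_names))
print(report)
```

Without feature_names, export_text falls back to generic labels such as feature_0, which is the text-mode equivalent of X[0].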

This produces a .dot file, which Graphviz (installed separately) can then render into a conventional image format (PostScript, PNG, etc.). For example, to make a PNG file, you'd run:

```sh
dot -Tpng filename.dot > filename.png
```

The dot file itself is plain text and fairly self-explanatory, so if you want to tweak the labels, a simple find-and-replace in the text editor of your choice will work. There are also Python modules for working with Graphviz and its files directly; pydot seems to be fairly popular, but there are others too.
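The find-and-replace route can also be done in Python on the DOT source itself. This sketch assumes a scikit-learn version in which export_graphviz returns the DOT source as a string when out_file=None, and in which unnamed features are labelled X[i] (some versions print a lowercase x[i], so the pattern accepts both). rename_features is a hypothetical helper, not a scikit-learn function; passing feature_names to export_graphviz directly is simpler whenever you control the export.

```python
import re

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)

# No feature_names: split labels come out as generic indices such as X[3].
dot_data = export_graphviz(clf, out_file=None)


def rename_features(dot_source, names):
    """Hypothetical helper: swap each X[i] (or x[i]) label for names[i]."""
    # A replacement function avoids any backslash processing of the names.
    return re.sub(r"[Xx]\[(\d+)\]",
                  lambda m: names[int(m.group(1))],
                  dot_source)


labelled = rename_features(dot_data, iris.feature_names)
print(labelled)
```

The resulting string can be written to a .dot file and rendered with the same dot command as before.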


The shape in fit's documentation just describes the layout of X, the training-data matrix: the first index runs over training examples and the second over features. For example, iris.data has shape (150, 4), so fit interprets it as 150 training examples, each consisting of four feature values.
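To make the shape concrete, here is a small sketch with made-up numbers (the data, targets and feature names are all hypothetical): X has one row per sample and one column per feature, the names live in a separate list, and packing names into each row, as in the first format tried in the question, is rejected because fit needs numeric input.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Hypothetical toy data: 4 samples (rows) x 2 features (columns).
X = np.array([[1.0, 2.0],
              [2.0, 1.0],
              [3.0, 4.0],
              [4.0, 3.0]])
y = np.array([1.0, 1.5, 3.0, 3.5])

# Names are kept outside X, one per column, in column order.
feature_names = ["feature_1", "feature_2"]

n_samples, n_features = X.shape
print(n_samples, n_features)  # prints: 4 2
assert n_features == len(feature_names)

reg = DecisionTreeRegressor(random_state=0).fit(X, y)

# Packing the names into every row (the first format tried in the question)
# fails: fit() must convert X to floats and "feature_1" is not a number.
bad_X = [[1, 2, "feature_1", "feature_2"] for _ in range(n_samples)]
rejected = False
try:
    DecisionTreeRegressor().fit(bad_X, y)
except ValueError as err:
    rejected = True
    print("rejected:", err)
```

The names only come back into play at export time, e.g. through the feature_names argument of export_graphviz.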

answered Sep 29 '22 11:09 by Matt Krause