
How to determine the range of the feature values used to predict a class using Random Forest Classifier

I am trying to find out the range of feature values that the random forest classifier uses to predict a certain class.

Take the Iris dataset as an example.

I use the random forest classifier to predict which species a flower belongs to based on its 4 features (sepal length, sepal width, petal length, petal width).

I can find out the feature importances, and I can visualize the steps the classifier takes by using Graphviz. Now I would like to find out the range of a feature that results in the prediction being Setosa, e.g. petal length between 0.2 and 0.4 indicates that the species is Setosa. I can view this visually using Graphviz, but I would like a way to store and analyse it for my entire dataset using 200 estimators. Is there a way to collect and store this information as text? For example, in the image of the decision tree below: if petal-length <= 2.6 then the class is Setosa.

https://images.app.goo.gl/pPK1KsXAMY3z27JW8

I would like to have a data frame similar to this:

          | node | feature      | Samples | Value  | Class
    ------+------+--------------+---------+--------+-----------
       1  |    1 | sepal-length |      23 | <= 0.2 | Setosa
       2  |    3 | petal-width  |      45 | <= 0.3 | Versicolor
      ... |  ... | ...          |     ... | ...    | ...
       n  |  178 | sepal-width  |       3 | <= 0.4 | Setosa

Once I have a data frame, I would be able to analyse it and see, for example, that a Setosa flower will have petal length between 0.1 and 0.3, sepal length between 0.4 and 0.7, etc.

Would this even be possible? If so, any ideas would be greatly appreciated.

EDIT: I have looked at the decision path for each tree, and while it is helpful, it does not contain the predicted class, so it does not help me with what I am trying to do.
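For context, here is a minimal sketch of what the forest's decision path exposes, assuming scikit-learn's bundled Iris data, a 200-tree forest and an arbitrary `random_state=0`: it is a sparse node-indicator matrix plus per-tree offsets, i.e. only the node ids a sample visits, with no feature, threshold or class attached.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
X, y = iris.data, iris.target
forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)

# indicator: sparse matrix of shape (n_samples, total nodes across all trees)
# n_nodes_ptr: column offsets, tree i owns columns n_nodes_ptr[i]:n_nodes_ptr[i + 1]
indicator, n_nodes_ptr = forest.decision_path(X[:1])

# node ids (local to tree 0) that the first sample passes through
path_tree0 = indicator[0, n_nodes_ptr[0]:n_nodes_ptr[1]].nonzero()[1]
print(path_tree0)  # just node ids, nothing about which class each node favours
```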

I think my only option is to parse the dot file that I get from the Graphviz export function and manually store the information in a data frame.
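As a possible alternative to parsing dot output, scikit-learn's `sklearn.tree.export_text` renders a fitted tree as indented plain-text rules, with leaf lines of the form `class: <label>`. A sketch under the same assumptions as before (Iris, 200 trees, `random_state=0`), dumping only the first tree; note that the forest's internal trees report encoded class indices rather than species names.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import export_text

iris = load_iris()
forest = RandomForestClassifier(n_estimators=200, random_state=0)
forest.fit(iris.data, iris.target)

# Text version of what the Graphviz image shows, for the first of the 200 trees
rules = export_text(forest.estimators_[0], feature_names=list(iris.feature_names))
print(rules)
```

Looping over `forest.estimators_` would give one rule dump per tree, though the result is still text that needs parsing to become a data frame.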

asked Nov 07 '22 18:11 by codiearcher



1 Answer

The RandomForestClassifier of scikit-learn has an attribute called estimators_ which after training is a list of DecisionTreeClassifier instances that together make up the forest.
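A quick sketch of that attribute, assuming the bundled Iris data and `n_estimators=200` as in the question (`random_state=0` is only for reproducibility):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(X, y)

# estimators_ only exists after fit: one fitted DecisionTreeClassifier per tree
trees = forest.estimators_
print(len(trees), type(trees[0]).__name__)  # 200 DecisionTreeClassifier
```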

Now that we can access the individual trees, we can take a closer look at the DecisionTreeClassifier instances. Every instance has a tree_ attribute, which contains the actual decision tree and all of the properties you are interested in.
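A minimal sketch of walking one tree's `tree_` arrays (`children_left`, `children_right`, `feature`, `threshold`, `n_node_samples`), again assuming Iris and a 200-tree forest. It relies on scikit-learn marking leaves with `-1` in `children_left`; the printed format is just for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(iris.data, iris.target)

tree = forest.estimators_[0].tree_  # low-level Tree object of the first tree
lines = []
for node in range(tree.node_count):
    left, right = tree.children_left[node], tree.children_right[node]
    if left == -1:
        # leaf: scikit-learn stores -1 for missing children
        lines.append(f"node {node}: leaf ({tree.n_node_samples[node]} samples)")
    else:
        # split node: feature index, threshold, and the two child node ids
        name = iris.feature_names[tree.feature[node]]
        lines.append(
            f"node {node}: {name} <= {tree.threshold[node]:.2f} "
            f"-> node {left}, else node {right}"
        )
print("\n".join(lines))
```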

The awesome people of scikit-learn even wrote documentation on how to access the tree's properties.

I understand you need the majority class per node, as shown in the graphviz visualization, which isn't stored as an attribute on the nodes. You mentioned that you might parse the output of graphviz, but maybe you can take a look at the graphviz export code instead!

As you can see at this line in the node_to_str function, they take the argmax over the value variable, which is defined here. I think that if you combine that with the documentation linked above, you should be able to get the class per node!
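Putting those pieces together, here is a sketch of the data frame from the question, assuming Iris, 200 trees and pandas. The column names (`tree`, `node`, `feature`, `samples`, `threshold`, `class`, `is_leaf`) are my own choice. The class per node is the argmax over `tree_.value`, mirroring what the graphviz exporter does, mapped back through `forest.classes_` because the forest's inner trees work on encoded class indices.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

iris = load_iris()
forest = RandomForestClassifier(n_estimators=200, random_state=0).fit(iris.data, iris.target)

rows = []
for tree_idx, est in enumerate(forest.estimators_):
    t = est.tree_
    # value has shape (node_count, n_outputs, n_classes); argmax over the
    # class axis gives the majority class at every node
    majority = t.value[:, 0, :].argmax(axis=1)
    for node in range(t.node_count):
        is_leaf = t.children_left[node] == -1
        rows.append({
            "tree": tree_idx,
            "node": node,
            "feature": None if is_leaf else iris.feature_names[t.feature[node]],
            "samples": int(t.n_node_samples[node]),
            "threshold": np.nan if is_leaf else float(t.threshold[node]),
            "class": iris.target_names[forest.classes_[majority[node]]],
            "is_leaf": bool(is_leaf),
        })

df = pd.DataFrame(rows)
print(df.head())
```

From there, something like `df[~df["is_leaf"]].groupby(["class", "feature"])["threshold"].describe()` summarizes split thresholds per majority class. Keep in mind a single threshold is a split point, not a range; the full range for a leaf comes from all the splits on its path from the root.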

answered Nov 14 '22 08:11 by zeebonk
