Expanding on a prior question: Changing colors for decision tree plot created using export graphviz
How would I color the nodes of the tree based on the dominant class (species of iris), instead of a binary distinction? This should require a combination of iris.target_names (the strings describing the classes) and iris.target (the class labels).
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf, out_file=None,
                                feature_names=iris.feature_names,
                                class_names=iris.target_names,
                                filled=True, rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()
edges = graph.get_edge_list()

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')
The code from the example looks so familiar and is therefore easy to modify :)
For each node, Graphviz tells us how many samples from each class it contains, i.e. whether it is a mixed population or the tree has reached a decision. We can extract this information from the node label and use it to pick a color.
values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
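A slightly more defensive take on this parsing step, as a sketch only: the helper name `parse_class_counts` and the sample label are illustrative assumptions, not part of pydotplus or sklearn. It uses a regular expression and `float()`, so labels without a `value = [...]` part are skipped and non-integer values do not raise.

```python
import re

def parse_class_counts(label):
    """Return the numbers inside 'value = [...]' of a node label, or None if absent."""
    match = re.search(r'value = \[([^\]]*)\]', label)
    if match is None:
        return None
    return [float(part) for part in match.group(1).split(',')]

# Hypothetical sample label, shaped like a node label in the exported dot data
sample = ('<petal width (cm) &le; 0.8<br/>gini = 0.667<br/>'
          'samples = 150<br/>value = [50, 50, 50]>')
counts = parse_class_counts(sample)
```

Returning None for labels without class values lets the calling loop skip such nodes instead of failing on them.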
Alternatively, you can map the Graphviz nodes back to the sklearn nodes:
values = clf.tree_.value[int(node.get_name())][0]
We only have 3 classes, so each one gets its own color channel (red, green, blue); mixed populations get mixed colors according to their class distribution.
values = [int(255 * v / sum(values)) for v in values]
color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
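To make the mapping concrete, here is the same arithmetic wrapped in a small function and applied to a few hand-picked class counts. The function name `counts_to_hex` and the example counts are assumptions for illustration; the formula is the one from the two lines above.

```python
def counts_to_hex(values):
    """Map three per-class values to an RGB hex string by their share of the total."""
    total = sum(values)
    channels = [int(255 * v / total) for v in values]
    return '#{:02x}{:02x}{:02x}'.format(*channels)

pure = counts_to_hex([50, 0, 0])     # only the 1st class -> pure red
mixed = counts_to_hex([0, 49, 5])    # mostly the 2nd class, a little of the 3rd
even = counts_to_hex([50, 50, 50])   # evenly mixed node, e.g. the root
```

Here `pure` is '#ff0000', `mixed` is '#00e717' and `even` is '#555555', so an evenly mixed node comes out dark gray rather than white. Because the values are divided by their sum, the result is the same whether the inputs are raw counts or proportions.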
We can now see the separation nicely: the greener a node is, the more of the 2nd class it contains, and likewise for blue and the 3rd class.
import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()
clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True,
                                special_characters=True)
graph = pydotplus.graph_from_dot_data(dot_data)
nodes = graph.get_node_list()

for node in nodes:
    # the default 'node' and 'edge' entries have no label and are skipped
    if node.get_label():
        # per-class values from the label; float() also accepts non-integer values
        values = [float(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]
        # scale each class share to a 0-255 color channel
        values = [int(255 * v / sum(values)) for v in values]
        color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2])
        node.set_fillcolor(color)

graph.write_png('colored_tree.png')
A general solution for more than 3 classes, which colors only the final (pure) nodes:
import numpy

# one color per class, plus a default color (last entry) for mixed nodes;
# 'lightcoral' replaces 'lightred', which is not a Graphviz color name
colors = ('lightblue', 'lightyellow', 'forestgreen', 'lightcoral', 'white')

for node in nodes:
    if node.get_name() not in ('node', 'edge'):
        values = clf.tree_.value[int(node.get_name())][0]
        # color only nodes where only one class is present
        if max(values) == sum(values):
            node.set_fillcolor(colors[numpy.argmax(values)])
        # mixed nodes get the default color
        else:
            node.set_fillcolor(colors[-1])
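If the number of classes is not known in advance, the hard-coded color tuple can be replaced by a generated palette. The following is only a sketch: `make_palette` and `pick_color` are hypothetical helper names, and spacing the hues evenly with the standard-library `colorsys` module is one possible choice, not something the approach above prescribes.

```python
import colorsys

def make_palette(n_classes, saturation=0.4, value=1.0):
    """Return n_classes pastel hex colors with evenly spaced hues."""
    palette = []
    for i in range(n_classes):
        r, g, b = colorsys.hsv_to_rgb(i / n_classes, saturation, value)
        palette.append('#{:02x}{:02x}{:02x}'.format(
            round(255 * r), round(255 * g), round(255 * b)))
    return palette

def pick_color(values, palette, mixed_color='#ffffff'):
    """Return the class color for a pure node, mixed_color for any other node."""
    values = list(values)
    if max(values) == sum(values):
        return palette[values.index(max(values))]
    return mixed_color

palette = make_palette(5)
pure_color = pick_color([0, 0, 7, 0, 0], palette)    # only the 3rd class present
mixed_node = pick_color([1, 2, 0, 0, 0], palette)    # two classes present
```

Inside the loop above, `node.set_fillcolor(pick_color(values, palette))` would then take the place of the if/else, with `palette = make_palette(clf.n_classes_)` built once beforehand.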