 

How to extract decision rules (features splits) from xgboost model in python3?

Tags:

python

xgboost

I need to extract the decision rules from my fitted xgboost model in Python. I use version 0.6a2 of the xgboost library, and my Python version is 3.5.2.

My ultimate goal is to use those splits to bin the variables.
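Once the split values for a feature have been collected, binning reduces to a lookup in the sorted thresholds. This is a minimal sketch that assumes `thresholds` is a sorted list of split values for a single feature; `bin_index` and `bin_labels` are hypothetical helpers. It follows XGBoost's convention that a row goes to the Yes branch when `feature < split`, so a value exactly equal to a threshold lands in the upper bin.

```python
from bisect import bisect_right


def bin_index(x, thresholds):
    """Return the bin of x for sorted split thresholds t0 < t1 < ...

    Bin 0 is x < t0, bin i is t(i-1) <= x < t(i), and the last bin is x >= t(-1),
    mirroring XGBoost's "go to Yes when feature < split" rule.
    """
    return bisect_right(thresholds, x)


def bin_labels(thresholds):
    """Return one human-readable interval label per bin."""
    if not thresholds:
        return ["(-inf, inf)"]
    labels = ["(-inf, {})".format(thresholds[0])]
    for lo, hi in zip(thresholds, thresholds[1:]):
        labels.append("[{}, {})".format(lo, hi))
    labels.append("[{}, inf)".format(thresholds[-1]))
    return labels
```

For example, with thresholds `[0.5, 1.5]` the value 0.5 falls into bin 1, labelled `[0.5, 1.5)`.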

I have not found any attribute of the model in this version that gives me the splits.

plot_tree gives me something similar, but it only produces a visualization of the tree.

I need something like https://stackoverflow.com/a/39772170/4559070, but for an xgboost model.
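A text counterpart of plot_tree is the Booster's `get_dump()`, which returns one string per tree. This is a minimal sketch, assuming numeric splits written as `[feature<value]` (the default text dump format), that collects the unique thresholds per feature; `split_thresholds` is a hypothetical helper. How you obtain the Booster depends on the version: `model.get_booster()` in newer releases, `model.booster()` in older ones.

```python
import re
from collections import defaultdict

# Matches split conditions such as "[f0<0.5]" or "[age<30.5]" in a text dump
SPLIT_RE = re.compile(r"\[([^\]<]+)<([^\]]+)\]")


def split_thresholds(dumps):
    """Collect sorted, unique split values per feature from text tree dumps."""
    thresholds = defaultdict(set)
    for tree in dumps:
        for feature, value in SPLIT_RE.findall(tree):
            thresholds[feature].add(float(value))
    return {feature: sorted(values) for feature, values in thresholds.items()}
```

Usage would look like `split_thresholds(model.get_booster().get_dump())`, which yields a dict such as `{"f0": [0.5, 1.5], ...}` that can be fed directly into a binning step.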

Artiga asked May 04 '18 at 13:05



2 Answers

You can get the decision rules as a DataFrame with model._Booster.trees_to_dataframe() (model.get_booster().trees_to_dataframe() is the public way to reach the same Booster). The Yes column contains the ID of the yes-branch child, and the No column contains the ID of the no-branch child; a row follows the Yes branch when its feature value is less than Split. Because every node ID has directed edges to its Yes and No children, you can reconstruct each tree from these columns, for example with networkx:

import networkx as nx

# `model` is a fitted xgboost.XGBClassifier or XGBRegressor
df = model._Booster.trees_to_dataframe()

# Create a directed graph: every split node points to its Yes and No children
G = nx.DiGraph()
# Add all the nodes
G.add_nodes_from(df.ID.tolist())
# Add the edges; leaf rows have NaN in Yes/No, so dropna() skips them
yes_edges = [tuple(pair) for pair in df[['ID', 'Yes']].dropna().values]
no_edges = [tuple(pair) for pair in df[['ID', 'No']].dropna().values]
G.add_edges_from(yes_edges + no_edges)
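If you only need the rules themselves, the graph library is optional: each tree can be walked recursively from its root. This is a stdlib-only sketch, not the answer's networkx approach. It assumes rows shaped like trees_to_dataframe() output (columns ID, Feature, Split, Yes, No, Missing, Gain, where a leaf has Feature == 'Leaf' and its value in Gain) and numeric splits only. It uses XGBoost's convention that a row goes to Yes when feature < split, and it marks the branch listed in Missing as also receiving missing values. `extract_rules` and `format_rule` are hypothetical helpers; with pandas you would pass `df.to_dict('records')`.

```python
def extract_rules(rows):
    """Return (conditions, leaf_value) for every root-to-leaf path.

    `rows` is a list of dicts with the trees_to_dataframe() columns
    ID, Feature, Split, Yes, No, Missing and Gain (leaf value for leaves).
    """
    nodes = {row["ID"]: row for row in rows}
    # A root is any node that is not the Yes/No child of some split node
    children = {row[key] for row in rows if row["Feature"] != "Leaf" for key in ("Yes", "No")}
    roots = [node_id for node_id in nodes if node_id not in children]
    rules = []

    def walk(node_id, conditions):
        node = nodes[node_id]
        if node["Feature"] == "Leaf":
            rules.append((conditions, node["Gain"]))
            return
        feature, split = node["Feature"], node["Split"]
        yes_cond = "{} < {}".format(feature, split)  # XGBoost goes Yes when feature < split
        no_cond = "{} >= {}".format(feature, split)
        # The child listed in Missing also receives rows where the feature is missing
        missing_note = " or {} is missing".format(feature)
        if node["Missing"] == node["Yes"]:
            yes_cond += missing_note
        else:
            no_cond += missing_note
        walk(node["Yes"], conditions + ["(" + yes_cond + ")"])
        walk(node["No"], conditions + ["(" + no_cond + ")"])

    for root in roots:
        walk(root, [])
    return rules


def format_rule(conditions, value):
    """Render one path as 'if ... and ... then value'."""
    return "if " + " and ".join(conditions) + " then {}".format(value)
```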
herman answered Oct 29 '22 at 03:10



The code below extracts every rule from the booster trees of an xgboost model, one rule per root-to-leaf path, and prints them.

It assumes that missing values were replaced with 999999 before training, which is why the rules mention the value 999999 next to the split conditions.

import networkx as nx

# `model` is a fitted xgboost.XGBClassifier or XGBRegressor
df = model._Booster.trees_to_dataframe()

G = nx.DiGraph()
G.add_nodes_from(df.ID.tolist())

# XGBoost sends a row to the Yes child when feature < split, otherwise to the No child.
# Missing values were replaced with 999999 before training, so (assuming 999999 exceeds
# every split value) they follow the No branch.
yes_edges = df[['ID', 'Yes', 'Feature', 'Split']].dropna().copy()
yes_edges['label'] = yes_edges.apply(lambda x: "({feature} < {value})".format(feature=x['Feature'], value=x['Split']), axis=1)

no_edges = df[['ID', 'No', 'Feature', 'Split']].dropna().copy()
no_edges['label'] = no_edges.apply(lambda x: "({feature} >= {value} or {feature} = 999999)".format(feature=x['Feature'], value=x['Split']), axis=1)

# Columns are ID, child ID, Feature, Split, label
for v in yes_edges.values:
    G.add_edge(v[0], v[1], feature=v[2], expr=v[4])

for v in no_edges.values:
    G.add_edge(v[0], v[1], feature=v[2], expr=v[4])

# For leaf rows, trees_to_dataframe() stores the leaf value in the Gain column
leaf_node_values = {i[0]: i[1] for i in df[df.Feature == 'Leaf'][['ID', 'Gain']].values}

roots = []
leaves = []
for node in G.nodes:
    if G.in_degree(node) == 0:  # it's a root
        roots.append(node)
    elif G.out_degree(node) == 0:  # it's a leaf
        leaves.append(node)

paths = []
for root in roots:
    for leaf in leaves:
        for path in nx.all_simple_paths(G, root, leaf):
            paths.append(path)

pred_conditions = []
for path in paths:
    parts = []
    for i in range(len(path) - 1):
        parts.append(G[path[i]][path[i + 1]]['expr'])
    pred_conditions.append("if " + " and ".join(parts) + " then {value:.4f}".format(value=leaf_node_values.get(path[-1])))

for rule in pred_conditions:
    print(rule)

The code above prints one rule per root-to-leaf path, in this format:

if x>y and a>b and c<d then e
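One way to sanity-check the extracted rules is to route a single row through one tree using the same convention and confirm it reaches the leaf whose rule it satisfies. This is a minimal sketch, assuming rows shaped like trees_to_dataframe() output and a sample given as a dict from feature name to value, where None or NaN means missing; `route` is a hypothetical helper, not part of the xgboost API.

```python
import math


def route(rows, root_id, sample):
    """Follow one tree from root_id and return (leaf_id, leaf_value) for sample.

    rows: list of dicts with trees_to_dataframe() columns (ID, Feature, Split,
    Yes, No, Missing, Gain). sample: dict feature -> value; None/NaN = missing.
    """
    nodes = {row["ID"]: row for row in rows}
    node = nodes[root_id]
    while node["Feature"] != "Leaf":
        value = sample.get(node["Feature"])
        if value is None or (isinstance(value, float) and math.isnan(value)):
            next_id = node["Missing"]  # default direction learned for missing values
        elif value < node["Split"]:
            next_id = node["Yes"]  # XGBoost takes Yes when feature < split
        else:
            next_id = node["No"]
        node = nodes[next_id]
    return node["ID"], node["Gain"]
```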
Rakesh Chintha answered Oct 29 '22 at 02:10
