I'm trying to refactor a trained Spark tree-based model (RandomForest or GBT classifiers) so that it can be exported to environments without Spark. The toDebugString method is a good starting point. However, in the case of RandomForestClassifier, the string only shows the predicted class for each tree, without the class probabilities. So if you average the predictions of all the trees, you get a wrong result.
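To see why averaging per-tree labels gives a different answer from averaging per-tree probabilities, here is a small sketch with invented leaf class counts for three trees. It assumes the forest averages each tree's normalised class distribution (soft voting) rather than counting hard votes; `vote_distribution` and `averaged_probability` are hypothetical helpers.

```python
import numpy as np


def vote_distribution(leaf_counts: np.ndarray) -> np.ndarray:
    """Fraction of trees whose leaf majority class is each class (hard voting)."""
    votes = np.argmax(leaf_counts, axis=1)
    return np.bincount(votes, minlength=leaf_counts.shape[1]) / len(leaf_counts)


def averaged_probability(leaf_counts: np.ndarray) -> np.ndarray:
    """Normalise each tree's leaf counts, then average across trees (soft voting)."""
    per_tree = leaf_counts / leaf_counts.sum(axis=1, keepdims=True)
    return per_tree.mean(axis=0)


# Rows are trees, columns are classes [0, 1]; counts are made up for illustration.
leaf_counts = np.array([[45.0, 55.0], [48.0, 52.0], [95.0, 5.0]])
print(vote_distribution(leaf_counts))     # hard votes favour class 1
print(averaged_probability(leaf_counts))  # averaged probabilities favour class 0
```

Two weak class-1 leaves outvote one confident class-0 leaf, but they do not outweigh it once the distributions are averaged.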
For example, consider a DecisionTree represented like this:
```
DecisionTreeClassificationModel (uid=dtc_884dc2111789) of depth 2 with 5 nodes
  If (feature 21 in {1.0})
   Predict: 0.0
  Else (feature 21 not in {1.0})
   If (feature 10 in {0.0})
    Predict: 0.0
   Else (feature 10 not in {0.0})
    Predict: 1.0
```
As we can see by following the nodes, the prediction always appears to be either 0 or 1. However, if I apply this single tree to a feature vector, I get probabilities like [0.1007, 0.8993], and they make perfect sense: in the training set, the proportion of negatives/positives that end up in the same leaf as the example vector matches the output probabilities.
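The relationship described above can be sketched as follows, assuming a leaf keeps per-class training counts and the printed `Predict:` value is simply their argmax. `leaf_output` is a hypothetical helper, not a Spark API, and the counts are invented.

```python
import numpy as np


def leaf_output(class_counts):
    """Return (predicted label, class probabilities) for a single leaf.

    class_counts: number of training instances of each class reaching the leaf.
    """
    counts = np.asarray(class_counts, dtype=float)
    probability = counts / counts.sum()  # leaf proportions become probabilities
    label = float(np.argmax(counts))     # the "Predict:" value toDebugString prints
    return label, probability


# Hypothetical leaf holding 1 negative and 9 positive training rows.
label, probability = leaf_output([1.0, 9.0])
```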
My questions: where are these probabilities stored? Is there a way to extract them? If so, how? A PySpark solution would be preferred.
Given the growing number of tools designed for real-time serving of Spark (and other) models, this is probably reinventing the wheel.
However, if you want to access model internals from plain Python, it is best to load its serialized form.
Let's say you have:
```python
from pyspark.ml.classification import RandomForestClassificationModel

rf_model: RandomForestClassificationModel
path: str  # Absolute path
```
And you save the model:
```python
rf_model.write().save(path)
```
You can load it back using any Parquet reader that supports nested struct and list types. The model writer stores both the node data:
```python
node_data = spark.read.parquet("{}/data".format(path))
node_data.printSchema()
```

```
root
 |-- treeID: integer (nullable = true)
 |-- nodeData: struct (nullable = true)
 |    |-- id: integer (nullable = true)
 |    |-- prediction: double (nullable = true)
 |    |-- impurity: double (nullable = true)
 |    |-- impurityStats: array (nullable = true)
 |    |    |-- element: double (containsNull = true)
 |    |-- rawCount: long (nullable = true)
 |    |-- gain: double (nullable = true)
 |    |-- leftChild: integer (nullable = true)
 |    |-- rightChild: integer (nullable = true)
 |    |-- split: struct (nullable = true)
 |    |    |-- featureIndex: integer (nullable = true)
 |    |    |-- leftCategoriesOrThreshold: array (nullable = true)
 |    |    |    |-- element: double (containsNull = true)
 |    |    |-- numCategories: integer (nullable = true)
```
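Because these are plain structs and arrays, one tree's collected rows can be exported (for example as JSON, after `Row.asDict(recursive=True)`) and walked with no Spark at all. A minimal sketch, assuming each record is a plain dict with the `nodeData` field names above, that the root has id 0, that leaves have `leftChild == -1`, and that continuous splits have `numCategories == -1` (the same conventions `model_data_to_tree` relies on). `predict_from_records` and the hand-written records (encoding the question's tree with invented stats) are hypothetical.

```python
from typing import Dict, List, Sequence


def predict_from_records(records: List[Dict], vector: Sequence[float]) -> List[float]:
    """Class probabilities for one tree, given its nodeData records as plain dicts."""
    nodes = {r["id"]: r for r in records}
    node = nodes[0]  # assumption: the root node has id 0
    while node["leftChild"] != -1:  # assumption: leaves have leftChild == -1
        split = node["split"]
        value = vector[split["featureIndex"]]
        left_values = split["leftCategoriesOrThreshold"]
        if split["numCategories"] == -1:  # continuous split: single threshold
            go_left = value <= left_values[0]
        else:  # categorical split: categories sent to the left child
            go_left = value in left_values
        node = nodes[node["leftChild"] if go_left else node["rightChild"]]
    stats = node["impurityStats"]
    total = sum(stats)
    return [s / total for s in stats]


def leaf(node_id, stats):
    # Hypothetical leaf record; the split is a dummy that is never read.
    return {"id": node_id, "leftChild": -1, "rightChild": -1, "impurityStats": stats,
            "split": {"featureIndex": -1, "leftCategoriesOrThreshold": [],
                      "numCategories": -1}}


records = [
    {"id": 0, "leftChild": 1, "rightChild": 2, "impurityStats": [60.0, 95.0],
     "split": {"featureIndex": 21, "leftCategoriesOrThreshold": [1.0], "numCategories": 2}},
    leaf(1, [30.0, 0.0]),
    {"id": 2, "leftChild": 3, "rightChild": 4, "impurityStats": [30.0, 95.0],
     "split": {"featureIndex": 10, "leftCategoriesOrThreshold": [0.0], "numCategories": 2}},
    leaf(3, [20.0, 5.0]),
    leaf(4, [10.0, 90.0]),
]

vector = [0.0] * 22
vector[10] = 1.0  # feature 21 = 0.0 goes right, feature 10 = 1.0 goes right
probs = predict_from_records(records, vector)
```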
and tree metadata:
```python
tree_meta = spark.read.parquet("{}/treesMetadata".format(path))
tree_meta.printSchema()
```

```
root
 |-- treeID: integer (nullable = true)
 |-- metadata: string (nullable = true)
 |-- weights: double (nullable = true)
```
where the former provides all the information you need, as the prediction process is basically an aggregation of impurityStats*.
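For random forests the `weights` column should be all 1.0, but for GBT models it matters: each GBT tree is a regression tree whose leaf `prediction` is a real value, so class counts are not aggregated. As far as I know, Spark combines them into a margin (the weighted sum of tree predictions) and, with the default logistic loss, maps it to a probability via 1 / (1 + exp(-2 * margin)). A minimal sketch under that assumption; `gbt_probability` is a hypothetical helper, and its output is worth checking against the model's `probability` column.

```python
import math
from typing import List, Sequence


def gbt_probability(tree_predictions: Sequence[float],
                    tree_weights: Sequence[float]) -> List[float]:
    """[P(class 0), P(class 1)] for a binary GBT, assuming Spark's default log loss."""
    # margin: weighted sum of each regression tree's leaf prediction
    margin = sum(p * w for p, w in zip(tree_predictions, tree_weights))
    # assumed log-loss link: P(class 1) = 1 / (1 + exp(-2 * margin))
    p1 = 1.0 / (1.0 + math.exp(-2.0 * margin))
    return [1.0 - p1, p1]
```

Here `tree_predictions` would be the leaf `prediction` each tree reaches for a vector, and `tree_weights` the `weights` column ordered by `treeID`.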
You could also access this data directly using the underlying Java objects:
```python
from collections import namedtuple

import numpy as np

# Plain-Python mirrors of Spark's tree nodes and splits
LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))


def jtree_to_python(jtree):
    def jsplit_to_python(jsplit):
        if jsplit.getClass().toString().endswith(".ContinuousSplit"):
            return ContinuousSplit(jsplit.featureIndex(), jsplit.threshold())
        else:
            # The old-style split exposes the left-going categories as a Scala List
            jcat = jsplit.toOld().categories()
            return CategoricalSplit(
                jsplit.featureIndex(),
                [jcat.apply(i) for i in range(jcat.length())])

    def jnode_to_python(jnode):
        prediction = jnode.prediction()
        # For classifiers: per-class counts of training rows reaching this node
        stats = np.array(list(jnode.impurityStats().stats()))

        if jnode.numDescendants() != 0:  # InternalNode
            left = jnode_to_python(jnode.leftChild())
            right = jnode_to_python(jnode.rightChild())
            split = jsplit_to_python(jnode.split())
            return InternalNode(left, right, prediction, stats, split)
        else:
            return LeafNode(prediction, stats)

    return jnode_to_python(jtree.rootNode())
```
which can be applied to a RandomForestClassificationModel like this:
```python
nodes = [jtree_to_python(t) for t in rf_model._java_obj.trees()]
```
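A cheap sanity check on the conversion is to compare each Python tree with the header of its `toDebugString` (for example "depth 2 with 5 nodes"), where a single leaf counts as depth 0. A sketch; `num_nodes` and `depth` are hypothetical helpers, and the tree below hand-encodes the question's example with placeholder impurity and split fields.

```python
from collections import namedtuple

# Same shapes as the namedtuples above, re-declared so this sketch runs standalone.
LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))


def num_nodes(node) -> int:
    """Total node count, comparable to 'with N nodes' in toDebugString."""
    if isinstance(node, LeafNode):
        return 1
    return 1 + num_nodes(node.left) + num_nodes(node.right)


def depth(node) -> int:
    """Tree depth with a single leaf counted as 0, comparable to 'of depth D'."""
    if isinstance(node, LeafNode):
        return 0
    return 1 + max(depth(node.left), depth(node.right))


# The question's tree; impurity and split are placeholders (None).
tree = InternalNode(
    left=LeafNode(0.0, None),
    right=InternalNode(LeafNode(0.0, None), LeafNode(1.0, None), 1.0, None, None),
    prediction=1.0, impurity=None, split=None)
```

For a whole forest, `[(num_nodes(t), depth(t)) for t in nodes]` can be compared against the per-tree headers.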
Furthermore, such a structure can easily be used to make predictions, both for individual trees (warning: Python 3.7+ ahead; for legacy usage please refer to the functools documentation):
```python
from functools import singledispatch


@singledispatch
def should_go_left(split, vector):
    raise TypeError("Unsupported split type: {}".format(type(split)))


@should_go_left.register
def _(split: CategoricalSplit, vector):
    return vector[split.feature_index] in split.categories


@should_go_left.register
def _(split: ContinuousSplit, vector):
    return vector[split.feature_index] <= split.threshold


@singledispatch
def predict(node, vector):
    raise TypeError("Unsupported node type: {}".format(type(node)))


@predict.register
def _(node: LeafNode, vector):
    # A leaf returns its label and its raw per-class stats
    return node.prediction, node.impurity


@predict.register
def _(node: InternalNode, vector):
    return predict(
        node.left if should_go_left(node.split, vector) else node.right,
        vector
    )
```
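If annotation-based `singledispatch` is not available (older Python) or you would rather avoid recursion, the same traversal can be written as a plain loop over the namedtuples. A sketch; `predict_iterative` is a hypothetical name, and the tree hand-encodes the question's `toDebugString` with invented impurity stats.

```python
from collections import namedtuple

# Same shapes as the namedtuples above, re-declared so this sketch runs standalone.
LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))


def predict_iterative(node, vector):
    """Walk from the root to a leaf with a loop instead of recursion/dispatch."""
    while isinstance(node, InternalNode):
        split = node.split
        value = vector[split.feature_index]
        if isinstance(split, CategoricalSplit):
            go_left = value in split.categories
        else:  # ContinuousSplit
            go_left = value <= split.threshold
        node = node.left if go_left else node.right
    return node.prediction, node.impurity


tree = InternalNode(
    left=LeafNode(0.0, [30.0, 0.0]),
    right=InternalNode(
        left=LeafNode(0.0, [20.0, 5.0]),
        right=LeafNode(1.0, [10.0, 90.0]),
        prediction=1.0, impurity=[30.0, 95.0],
        split=CategoricalSplit(10, [0.0])),
    prediction=1.0, impurity=[60.0, 95.0],
    split=CategoricalSplit(21, [1.0]))

vector = [0.0] * 22
vector[10] = 1.0
result = predict_iterative(tree, vector)
```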
and forests:
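```python
from typing import Iterable, Union


def predict_probability(nodes: Iterable[Union[InternalNode, LeafNode]], vector):
    # Normalise each tree's leaf stats to a distribution, then sum over trees
    total = np.array([
        v / v.sum() for _, v in (
            predict(node, vector) for node in nodes
        )
    ]).sum(axis=0)
    # Renormalise so the forest's probabilities sum to 1
    return total / total.sum()
```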
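`predict_probability` corresponds to the model's probability column; the prediction column is the argmax of it. If the model was trained with the `thresholds` param, my understanding is that Spark divides each probability by its class threshold (treating a zero threshold as infinity) before taking the argmax. The sketch below assumes that behaviour; `probability_to_label` is a hypothetical helper and should be checked against `model.transform`.

```python
import numpy as np


def probability_to_label(probability, thresholds=None):
    """Predicted class index (as a float, like Spark's prediction column)."""
    probability = np.asarray(probability, dtype=float)
    if thresholds is None:
        return float(np.argmax(probability))
    t = np.asarray(thresholds, dtype=float)
    # Assumed Spark behaviour: scale by 1/threshold; a zero threshold counts as +inf
    scaled = np.full_like(probability, np.inf)
    np.divide(probability, t, out=scaled, where=t != 0)
    return float(np.argmax(scaled))
```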
That, however, depends on the internal API (and on the weakness of Scala's package-scoped access modifiers) and might break in the future.
* The DataFrame loaded from the data path can easily be transformed into a structure compatible with the predict and predict_probability functions defined above.
```python
from itertools import groupby
from operator import itemgetter

import numpy as np
from pyspark.sql.dataframe import DataFrame


def model_data_to_tree(tree_data: DataFrame):
    def dict_to_tree(node_id, nodes):
        node = nodes[node_id]
        prediction = node.prediction
        impurity = np.array(node.impurityStats)

        # Leaves are stored with both child ids set to -1
        if node.leftChild == -1 and node.rightChild == -1:
            return LeafNode(prediction, impurity)
        else:
            left = dict_to_tree(node.leftChild, nodes)
            right = dict_to_tree(node.rightChild, nodes)
            feature_index = node.split.featureIndex
            left_value = node.split.leftCategoriesOrThreshold
            # numCategories is -1 for continuous splits
            split = (
                CategoricalSplit(feature_index, left_value)
                if node.split.numCategories != -1
                else ContinuousSplit(feature_index, left_value[0])
            )
            return InternalNode(left, right, prediction, impurity, split)

    tree_id = itemgetter("treeID")
    rows = tree_data.collect()
    # Ensemble data has a treeID column; a single decision tree's data does not
    return ([
        dict_to_tree(0, {node.nodeData.id: node.nodeData for node in nodes})
        for tree, nodes in groupby(sorted(rows, key=tree_id), key=tree_id)
    ] if "treeID" in tree_data.columns
        else [dict_to_tree(0, {node.id: node for node in rows})])
```
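Since the goal is to use the trees where Spark is not available, the namedtuple structure still needs a portable on-disk format. One option, sketched here assuming JSON is acceptable, is a nested dict that round-trips back to the same namedtuples (impurity restored as NumPy arrays, as `predict_probability` expects). `tree_to_dict` and `tree_from_dict` are hypothetical helpers, and the example tree is invented.

```python
import json
from collections import namedtuple

import numpy as np

# Same shapes as the namedtuples above, re-declared so this sketch runs standalone.
LeafNode = namedtuple("LeafNode", ("prediction", "impurity"))
InternalNode = namedtuple(
    "InternalNode", ("left", "right", "prediction", "impurity", "split"))
CategoricalSplit = namedtuple("CategoricalSplit", ("feature_index", "categories"))
ContinuousSplit = namedtuple("ContinuousSplit", ("feature_index", "threshold"))


def tree_to_dict(node):
    """Convert a namedtuple tree into JSON-serialisable nested dicts."""
    out = {"prediction": float(node.prediction),
           "impurity": [float(x) for x in node.impurity]}
    if isinstance(node, LeafNode):
        return out
    split = node.split
    if isinstance(split, CategoricalSplit):
        out["split"] = {"feature_index": int(split.feature_index),
                        "categories": [float(c) for c in split.categories]}
    else:
        out["split"] = {"feature_index": int(split.feature_index),
                        "threshold": float(split.threshold)}
    out["left"] = tree_to_dict(node.left)
    out["right"] = tree_to_dict(node.right)
    return out


def tree_from_dict(d):
    """Inverse of tree_to_dict; nodes without a 'split' key are leaves."""
    impurity = np.array(d["impurity"])
    if "split" not in d:
        return LeafNode(d["prediction"], impurity)
    s = d["split"]
    split = (CategoricalSplit(s["feature_index"], s["categories"])
             if "categories" in s
             else ContinuousSplit(s["feature_index"], s["threshold"]))
    return InternalNode(tree_from_dict(d["left"]), tree_from_dict(d["right"]),
                        d["prediction"], impurity, split)


tree = InternalNode(
    left=LeafNode(0.0, np.array([30.0, 0.0])),
    right=InternalNode(LeafNode(0.0, np.array([20.0, 5.0])),
                       LeafNode(1.0, np.array([10.0, 90.0])),
                       1.0, np.array([30.0, 95.0]), ContinuousSplit(10, 0.5)),
    prediction=1.0, impurity=np.array([60.0, 95.0]),
    split=CategoricalSplit(21, [1.0]))

payload = json.dumps(tree_to_dict(tree))
restored = tree_from_dict(json.loads(payload))
```

For a forest, a JSON list of such dicts (one per tree) keeps the export to plain `json` and `numpy` on the consuming side.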