I have built a decision tree model in R using `rpart` and `ctree`.
I have also used the fitted model to predict on a new dataset and obtained the predicted probabilities and classes.
However, I would like to extract, as a single string, the rule/path that each observation in the predicted dataset has followed. With this stored in a tabular format, I can explain each prediction and its reason in an automated manner without opening R.
In other words, I want to get the following:
| ObsID | Probability | PredictedClass | PathFollowed |
|---|---|---|---|
| 1 | 0.68 | Safe | CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5 |
| 2 | 0.76 | Safe | CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5 |
| 3 | 0.88 | Unsafe | CarAge > 10 & Type = Van & Country = USA & Price > 15988 |
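The kind of code I'm looking for is:

```r
library(rpart)
fit <- rpart(Reliability ~ ., data = car.test.frame)

# this is what needs to be expanded, possibly into multiple lines
# (type = "GETPATTERNS" is not a real predict() type; it only illustrates what I want)
predResults <- predict(fit, newdata = newcar, type = "GETPATTERNS")
```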
The `partykit` package has a function `.list.rules.party()` that is currently unexported but can be leveraged to do what you want. The main reason we haven't exported it yet is that its output format may change in future versions.
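To see what `.list.rules.party()` returns before wiring it into a prediction function, a minimal sketch like the following can be used. It assumes a partykit version in which the helper is still unexported (hence the `:::`), and its exact output may differ in other versions.

```r
library("partykit")

# Fit a conditional inference tree on the built-in iris data
ct <- ctree(Species ~ ., data = iris)

# Unexported helper: one rule string per terminal node.
# The result is a character vector whose names are the terminal node IDs,
# which is what allows indexing it by predicted node later on.
rls <- partykit:::.list.rules.party(ct)
print(rls)

# The names line up with the IDs of the terminal nodes
nodeids(ct, terminal = TRUE)
```

Because the vector is keyed by node ID, mapping an observation to its rule only requires knowing which terminal node it falls into.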
To obtain the predictions you describe above you can do:
```r
pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}
```
Illustration using the `iris` data and `rpart()`:
```r
library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                          rule
## 1                         Petal.Length < 2.45
## 51  Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75
```
(Only the first observation of each species is shown for brevity here. This corresponds to indexes 1, 51, and 101.)
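Because each rule is an ordinary R logical expression built from the split variables, it can be evaluated directly against data. The following sketch (an illustration, not part of the approach above) checks on the training data that, for every terminal node of the `rpart()` tree, the observations satisfying the rule string are exactly those routed to that node. It assumes there are no missing values in the predictors (true for `iris`), since rows sent down a branch by rpart's surrogate splits would not necessarily satisfy the printed rule.

```r
library("rpart")
library("partykit")

rp <- rpart(Species ~ ., data = iris)
pr <- as.party(rp)

# Rule string per terminal node (unexported helper, output format may change)
rls <- partykit:::.list.rules.party(pr)

# Terminal node each training observation falls into
nodes <- predict(pr, type = "node")

# For each node: does evaluating its rule on iris select exactly that node's rows?
check <- sapply(names(rls), function(id) {
  in_rule <- eval(parse(text = rls[[id]]), envir = iris)
  all(in_rule == (nodes == as.integer(id)))
})
print(check)
```

This is also a quick way to confirm that a stored rule string still describes the same subset of data when it is read back later.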
And with `ctree()`:
```r
ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                                rule
## 1                                               Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
```
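To get the layout from the question (one row per observation with an ID, the probability of the predicted class, the class, and the path) for genuinely new data, the same building blocks can be wrapped slightly differently. `path_table()` below is a hypothetical helper, not part of partykit; it takes the predicted-class probability as the row-wise maximum of the class probabilities and writes the result to a CSV file so it can be inspected without opening R. The train/test split of `iris` is only for illustration.

```r
library("rpart")
library("partykit")

# Hypothetical helper (not part of partykit): one row per observation with
# the columns ObsID, Probability, PredictedClass, PathFollowed
path_table <- function(object, newdata) {
  ## coerce to "party" object if necessary
  if (!inherits(object, "party")) object <- as.party(object)

  prob <- predict(object, newdata = newdata, type = "prob")
  cls  <- predict(object, newdata = newdata, type = "response")
  node <- predict(object, newdata = newdata, type = "node")

  ## unexported helper; its output format may change in future versions
  rls <- partykit:::.list.rules.party(object)

  data.frame(
    ObsID          = seq_len(nrow(newdata)),
    Probability    = unname(apply(prob, 1, max)),  # prob. of the predicted class
    PredictedClass = unname(cls),
    PathFollowed   = unname(rls[as.character(node)]),
    stringsAsFactors = FALSE
  )
}

## fit on part of iris, then treat the held-out rows as new data
set.seed(1)
train <- sample(nrow(iris), 100)
fit   <- rpart(Species ~ ., data = iris[train, ])
tab   <- path_table(fit, newdata = iris[-train, ])
head(tab)

## store as a plain CSV table that can be read outside R
out <- tempfile(fileext = ".csv")
write.csv(tab, out, row.names = FALSE)
```

In practice `out` would be replaced by a permanent file path; `tempfile()` is used here only so the sketch does not write into the working directory.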