 

Get decision tree rule/path pattern for every row of predicted dataset for rpart/ctree package in R

I have built decision tree models in R using rpart and ctree. I have also used them to predict a new dataset and obtained the predicted probabilities and classes.

However, I would like to extract, as a single string, the rule/path that every observation in the predicted dataset has followed. With this stored in tabular form, I can explain each prediction and its reason in an automated manner, without opening R.

That is, I want to get the following:

ObsID   Probability   PredictedClass   PathFollowed 
    1          0.68             Safe   CarAge < 10 & Country = Germany & Type = Compact & Price < 12822.5
    2          0.76             Safe   CarAge < 10 & Country = Korea & Type = Compact & Price > 12822.5
    3          0.88           Unsafe   CarAge > 10 & Type = Van & Country = USA & Price > 15988

The kind of code I'm looking for is:

library(rpart)
fit <- rpart(Reliability~.,data=car.test.frame)

This is the part that needs to be expanded, possibly into multiple lines:

predResults <- predict(fit, newdata = newcar, type = "GETPATTERNS")  # pseudocode: predict.rpart() has no such type
arindam adak asked Apr 14 '15 03:04


1 Answer

The partykit package has a function .list.rules.party() that is currently unexported but can be leveraged to do what you want. The main reason we haven't exported it yet is that its output format may change in future versions.

To obtain the predictions you describe above you can do:

pathpred <- function(object, ...)
{
  ## coerce to "party" object if necessary
  if(!inherits(object, "party")) object <- as.party(object)

  ## get standard predictions (response/prob) and collect in data frame
  rval <- data.frame(response = predict(object, type = "response", ...))
  rval$prob <- predict(object, type = "prob", ...)

  ## get rules for each node
  rls <- partykit:::.list.rules.party(object)

  ## get predicted node and select corresponding rule
  rval$rule <- rls[as.character(predict(object, type = "node", ...))]

  return(rval)
}
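To see what the two lookups inside pathpred() actually return, here is a minimal sketch on the iris tree. It relies on the same unexported partykit:::.list.rules.party(), so its output format may change between partykit versions; the object names rp, rls and nodes are only illustrative. The rules come back as a character vector named by terminal node id, and predict(type = "node") returns those ids for each observation, which is why indexing the rules with as.character() of the node ids works.

```r
library("rpart")
library("partykit")

## convert an rpart fit to a party object, as pathpred() does internally
rp <- as.party(rpart(Species ~ ., data = iris))

## named character vector: one rule string per terminal node, named by node id
## (unexported helper; its output format may change in future partykit versions)
rls <- partykit:::.list.rules.party(rp)
print(rls)

## terminal node id reached by each observation; this is the lookup key
nodes <- predict(rp, type = "node")
table(nodes)
```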

Illustration using the iris data and rpart():

library("rpart")
library("partykit")
rp <- rpart(Species ~ ., data = iris)
rp_pred <- pathpred(rp)
rp_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.90740741     0.09259259
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                           rule
## 1                          Petal.Length < 2.45
## 51   Petal.Length >= 2.45 & Petal.Width < 1.75
## 101 Petal.Length >= 2.45 & Petal.Width >= 1.75

(Only the first observation of each species is shown for brevity here. This corresponds to indexes 1, 51, and 101.)
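Because pathpred() forwards ... to predict(), the same approach should also work for a new dataset, as in the question. The sketch below does the lookup by hand on held-out rows. The every-fifth-row split and the names train and newiris are hypothetical, chosen only for illustration, and a tree fitted on 120 rows need not match the one shown above. With pathpred() defined, pathpred(rpart(Species ~ ., data = train), newdata = newiris) should return the same rules together with the predictions.

```r
library("rpart")
library("partykit")

## hypothetical split: hold out every fifth row as "new" data
idx <- seq(5, nrow(iris), by = 5)
train <- iris[-idx, ]
newiris <- iris[idx, ]

## fit on the training rows only and convert to a party object
rp <- as.party(rpart(Species ~ ., data = train))

## terminal node reached by each new observation
nodes <- predict(rp, newdata = newiris, type = "node")

## rule for each terminal node (unexported helper), looked up per observation
rls <- partykit:::.list.rules.party(rp)
newiris$rule <- unname(rls[as.character(nodes)])
head(newiris[, c("Species", "rule")])
```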

And with ctree():

ct <- ctree(Species ~ ., data = iris)
ct_pred <- pathpred(ct)
ct_pred[c(1, 51, 101), ]
##       response prob.setosa prob.versicolor prob.virginica
## 1       setosa  1.00000000      0.00000000     0.00000000
## 51  versicolor  0.00000000      0.97826087     0.02173913
## 101  virginica  0.00000000      0.02173913     0.97826087
##                                                              rule
## 1                                             Petal.Length <= 1.9
## 51  Petal.Length > 1.9 & Petal.Width <= 1.7 & Petal.Length <= 4.8
## 101                        Petal.Length > 1.9 & Petal.Width > 1.7
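To get the one-row-per-observation table sketched in the question and keep it outside R, the pieces can be combined and written to a CSV file. This is a minimal sketch under a few assumptions: iris stands in for the car data, "Probability" is taken to mean the probability of the predicted class (computed here as the largest class probability), and the names paths and outfile are hypothetical. The file is written to tempdir() only for illustration.

```r
library("partykit")

## conditional inference tree; iris stands in for the question's car data
ct <- ctree(Species ~ ., data = iris)

## class probabilities, predicted class and terminal node for each row
prob  <- predict(ct, type = "prob")
cls   <- predict(ct, type = "response")
nodes <- predict(ct, type = "node")

## rule for each terminal node (unexported helper, output may change)
rls <- partykit:::.list.rules.party(ct)

## one row per observation, in the layout from the question;
## assumption: Probability = probability of the predicted (most likely) class
paths <- data.frame(
  ObsID          = seq_len(nrow(iris)),
  Probability    = apply(prob, 1, max),
  PredictedClass = cls,
  PathFollowed   = unname(rls[as.character(nodes)])
)

## export so the explanations can be read without opening R
outfile <- file.path(tempdir(), "paths.csv")
write.csv(paths, outfile, row.names = FALSE)
```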
Achim Zeileis answered Sep 26 '22 11:09