
R: Extracting Rules from a Decision Tree

I am working with the R programming language. Recently, I read about a new decision tree algorithm called "Reinforcement Learning Trees" (RLT), which supposedly has the potential to fit "better" decision trees to a dataset. The documentation for this library is available here: https://cran.r-project.org/web/packages/RLT/RLT.pdf

I tried to use this library to run a classification decision tree on the (famous) Iris Dataset:

library(RLT)
data(iris)
fit = RLT(iris[,c(1,2,3,4)], iris$Species, model = "classification", ntrees = 1)

Question: From here, is it possible to extract the "rules" from this decision tree?

For example, if you use the CART Decision Tree model:

library(rpart)
library(rpart.plot)

fit <- rpart(Species ~ ., data = iris)
rpart.plot(fit)

[Figure: rpart.plot output for the CART tree fitted to iris]

rpart.rules(fit)

    Species  seto vers virg                                               
     setosa [1.00  .00  .00] when Petal.Length <  2.5                     
 versicolor [ .00  .91  .09] when Petal.Length >= 2.5 & Petal.Width <  1.8
  virginica [ .00  .02  .98] when Petal.Length >= 2.5 & Petal.Width >= 1.8

Is it possible to do this with the RLT library? I have been reading the documentation for this library and cannot find a direct way to extract the decision rules. I understand that this library is typically meant to be used as a substitute for random forests (which do not have decision rules), but the original paper for this algorithm specifies that RLT fits individual decision trees and then aggregates them, as in a random forest. Thus, on some level, the RLT algorithm is able to fit an individual decision tree, which in theory should have "decision rules".

Does anyone know how to extract these rules?

Thanks!

References:

  • https://www.researchgate.net/publication/277625959_Reinforcement_Learning_Trees
asked Nov 02 '21 at 02:11 by stats_noob



1 Answer

The rules are stored in fit$FittedTrees[[1]] in a tabular format that is relatively difficult to interpret.

I have constructed a rather long function for you that will extract the rules as a data frame and additionally plot the tree as a ggplot if requested.

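library(ggplot2)  # needed for the plotting branch below

RLT_tree <- function(RLT_obj, plot = TRUE)
{
  # The first fitted tree is transposed so that each row describes one node
  tree <- as.data.frame(t(RLT_obj$FittedTrees[[1]]))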
  tree <- tree[c(2, 3, 5, 6, 8, 9, grep("Class\\d", names(tree)))]
  class_cols <- grep("Class\\d", names(tree))
  names(tree)[class_cols] <-
    RLT_obj$ylevels[1 + as.numeric(sub("Class(\\d+)", "\\1",
                                   names(tree)[class_cols]))]
  tree$variable <- RLT_obj$variablenames[tree$SplitVar1]
  tree$variable[is.na(tree$variable)] <- "(Leaf node)"
  tree$rule <- tree$variable
  tree$depth <- numeric(nrow(tree))
  tree$rightness <- numeric(nrow(tree))
  tree$group <- character(nrow(tree))
  
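  # Depth-first walk from the root: records each node's depth, horizontal
  # plot position, sibling group, and the split condition that led to it
  walk_tree <- function(node, depth, rightness, node_label = "Start", group = "S")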
  {
    new_row <- tree[which(tree$Node == node),]
    new_row$depth <- depth
    new_row$rightness <- rightness
    left_label <- paste(new_row$variable, new_row$SplitValue, sep = " < ")
    right_label <- paste(new_row$variable, new_row$SplitValue, sep = " > ")
    new_row$variable <- paste(node_label, "\nn = ", new_row$NumObs)
    new_row$rule <- node_label
    if(is.nan(new_row$SplitValue)) {
      n_objs <- round(new_row[,class_cols] * new_row$NumObs)
      classify <- paste((names(tree)[class_cols])[n_objs != 0], 
                        n_objs[n_objs != 0],
                        collapse = "\n")
      new_row$variable <- paste(new_row$variable, classify, sep = "\n")
    }
    new_row$group <- group
    tree[which(tree$Node == node),] <<- new_row
    if(!is.nan(new_row$SplitValue)){
      walk_tree(new_row$NextLeft, depth + 1, rightness - 2/(depth/2 + 1), 
                left_label, paste(group, "L"))
      walk_tree(new_row$NextRight, depth + 1, rightness + 2/(depth/2 + 1), 
                right_label, paste(group, "R"))
    }
  }
  
  walk_tree(0, 0, 0)
  tree$depth <- max(tree$depth) - tree$depth
  tree$type <- is.nan(tree$NextLeft)
  tree$group <- substr(tree$group, 1, nchar(tree$group) - 1)

  if(plot)
  {
    print(ggplot(tree, aes(rightness, depth)) + 
      geom_segment(aes(x = rightness, xend = rightness, 
                       y = depth, yend = depth - 1, alpha = type)) + 
      geom_line(aes(group = group)) +
      geom_label(aes(label = variable, fill = type), size = 4) + 
      theme_void() + 
      scale_x_continuous(expand = c(0, 1)) + 
      suppressWarnings(scale_alpha_discrete(range = c(1, 0)))  +
      theme(legend.position = "none"))
  }
  tree$isLeaf <- is.nan(tree$NextLeft)
  tree[c(match(c("Node", "rule", "depth", "isLeaf"), names(tree)), class_cols)]
}

This allows:

df <- RLT_tree(fit, plot = TRUE)

[Figure: ggplot of the RLT tree fitted to iris]

and

df
#>    Node               rule depth isLeaf    setosa versicolor virginica
#> 1     0              Start     6  FALSE 0.3111111 0.34814815 0.3407407
#> 2     1  Sepal.Width < 3.2     5  FALSE 0.1573034 0.51685393 0.3258427
#> 3     2  Sepal.Width > 3.2     5  FALSE 0.6086957 0.02173913 0.3695652
#> 4     3 Sepal.Length < 5.4     4  FALSE 0.7000000 0.30000000 0.0000000
#> 5     4 Sepal.Length > 5.4     4   TRUE 0.0000000 0.57971014 0.4202899
#> 6     5 Petal.Length < 1.3     3   TRUE 1.0000000 0.00000000 0.0000000
#> 7     6 Petal.Length > 1.3     3  FALSE 0.6000000 0.40000000 0.0000000
#> 8     7 Petal.Length < 1.4     2   TRUE 1.0000000 0.00000000 0.0000000
#> 9     8 Petal.Length > 1.4     2  FALSE 0.5000000 0.50000000 0.0000000
#> 10    9 Petal.Length < 3.9     1  FALSE 0.7500000 0.25000000 0.0000000
#> 11   10 Petal.Length > 3.9     1   TRUE 0.0000000 1.00000000 0.0000000
#> 12   11 Sepal.Length < 4.9     0   TRUE 1.0000000 0.00000000 0.0000000
#> 13   12 Sepal.Length > 4.9     0   TRUE 0.0000000 1.00000000 0.0000000
#> 14   13  Petal.Width < 0.2     4   TRUE 1.0000000 0.00000000 0.0000000
#> 15   14  Petal.Width > 0.2     4  FALSE 0.3793103 0.03448276 0.5862069
#> 16   15 Sepal.Length < 5.7     3   TRUE 1.0000000 0.00000000 0.0000000
#> 17   16 Sepal.Length > 5.7     3  FALSE 0.0000000 0.05555556 0.9444444
#> 18   17  Sepal.Width < 3.3     2   TRUE 0.0000000 0.00000000 1.0000000
#> 19   18  Sepal.Width > 3.3     2  FALSE 0.0000000 0.08333333 0.9166667
#> 20   19 Petal.Length < 6.1     1  FALSE 0.0000000 0.11111111 0.8888889
#> 21   20 Petal.Length > 6.1     1   TRUE 0.0000000 0.00000000 1.0000000
#> 22   21 Sepal.Length < 6.3     0   TRUE 0.0000000 0.16666667 0.8333333
#> 23   22 Sepal.Length > 6.3     0   TRUE 0.0000000 0.00000000 1.0000000
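
Each row of this table holds only the condition that leads into that node. To get complete rules in the style of rpart.rules, the conditions along each root-to-leaf path have to be ANDed together. Below is a minimal base-R sketch of that idea. It runs on a small hypothetical node table (the values and the `class` column are made up for illustration) whose `Node`, `SplitValue`, `NextLeft` and `NextRight` columns mirror the names used inside RLT_tree(). The helper `path_rules()` is not part of RLT. The "<" / ">" labels follow the function above; which side receives values equal to the split point is not something this sketch establishes.

```r
# Hypothetical node table for illustration only. Leaves are marked by NaN in
# SplitValue, NextLeft and NextRight, as in the table used by RLT_tree().
nodes <- data.frame(
  Node       = 0:4,
  variable   = c("Petal.Length", NA, "Petal.Width", NA, NA),
  SplitValue = c(2.5, NaN, 1.8, NaN, NaN),
  NextLeft   = c(1, NaN, 3, NaN, NaN),
  NextRight  = c(2, NaN, 4, NaN, NaN),
  class      = c(NA, "setosa", NA, "versicolor", "virginica"),
  stringsAsFactors = FALSE
)

# Recursively collect one "IF ... THEN ..." rule per root-to-leaf path.
path_rules <- function(nodes, node = 0, conds = character(0)) {
  row <- nodes[nodes$Node == node, ]
  if (is.nan(row$SplitValue)) {
    # Leaf: AND together every condition met on the way down
    antecedent <- if (length(conds)) paste(conds, collapse = " & ") else "TRUE"
    return(paste("IF", antecedent, "THEN", row$class))
  }
  # Internal node: extend the path with the left and right conditions
  left  <- paste(row$variable, "<", row$SplitValue)
  right <- paste(row$variable, ">", row$SplitValue)
  c(path_rules(nodes, row$NextLeft,  c(conds, left)),
    path_rules(nodes, row$NextRight, c(conds, right)))
}

rules <- path_rules(nodes)
print(rules)
```

To apply the same walk to a real RLT tree, RLT_tree() would need to also return the `NextLeft`, `NextRight`, `SplitValue` and variable-name columns, which it currently drops in its last line.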

To show this works in the more general case, we can also do:

fit2 = RLT(mtcars[,1:3], factor(rownames(mtcars)), model = "classification", ntrees = 1)

df <- RLT_tree(fit2)

[Figure: ggplot of the RLT tree fitted to mtcars]

answered Oct 18 '22 at 05:10 by Allan Cameron