Plotting decision tree results from tidymodels

I have managed to build a decision tree model using the tidymodels package, but I am unsure how to pull the results and plot the tree. I know I can use the rpart and rpart.plot packages to achieve the same thing, but I would rather use tidymodels, as that is what I am learning. Below is an example using the mtcars data.

library(tidymodels)
library(rpart)
library(rpart.plot)
library(dplyr) #mutate() and %>%; mtcars itself ships with base R's datasets package

#data
df <- mtcars %>%
    mutate(gear = factor(gear))


#train/test
set.seed(1234)

df_split <- initial_split(df)
df_train <- training(df_split)
df_test <- testing(df_split)


df_recipe <- recipe(gear ~ ., data = df) %>%
  step_normalize(all_numeric())


#building model
tree <- decision_tree() %>%
   set_engine("rpart") %>%
   set_mode("classification")

#workflow
tree_wf <- workflow() %>%
  add_recipe(df_recipe) %>%
  add_model(tree) %>%
  fit(df_train) #results are found here

rpart.plot(tree_wf$fit$fit) #error is here

The error I get is "Error in rpart.plot(tree_wf$fit$fit) : Not an rpart object", which makes sense, but I don't know whether there is a package or step I am missing to convert the results into a format that rpart.plot can plot. This might not be possible, but any help would be much appreciated.

Asked Aug 21 '20 17:08 by Edgar Zamora

2 Answers

You can also use the workflows::pull_workflow_fit() function, which returns the parsnip fit stored in the workflow; its $fit element is the rpart object. It makes the code a little more elegant.

tree_fit <- tree_wf %>% 
  pull_workflow_fit()
rpart.plot(tree_fit$fit)
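
In newer releases of workflows, pull_workflow_fit() is deprecated in favor of the extract_*() helpers. Here is a minimal sketch of the same idea using extract_fit_engine(), assuming a recent tidymodels installation. The workflow is rebuilt in compact form so the block runs on its own, and tree_engine is just an illustrative name. When the tree is fitted through parsnip, rpart.plot typically warns that it cannot retrieve the training data; roundint = FALSE is the argument that warning suggests.

```r
library(tidymodels)   # attaches workflows, parsnip, recipes, rsample, dplyr
library(rpart.plot)

set.seed(1234)
df <- mtcars %>% mutate(gear = factor(gear))
df_train <- training(initial_split(df))

# Same recipe + model as in the question, written inline
tree_wf <- workflow() %>%
  add_recipe(recipe(gear ~ ., data = df_train) %>% step_normalize(all_numeric())) %>%
  add_model(decision_tree() %>% set_engine("rpart") %>% set_mode("classification")) %>%
  fit(df_train)

# extract_fit_engine() returns the underlying engine object (here, the rpart tree)
tree_engine <- extract_fit_engine(tree_wf)

# roundint = FALSE avoids the "cannot retrieve the data" warning
rpart.plot(tree_engine, roundint = FALSE)
```

extract_fit_parsnip(tree_wf) would instead return the parsnip wrapper, in which case the tree is its $fit element, as in the code above.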

Answered Sep 28 '22 19:09 by hnagaty


The following works (note the extra $fit):

rpart.plot(tree_wf$fit$fit$fit)

Not a very elegant solution, but it does plot the tree.

Tested with parsnip 0.1.3 and rpart.plot 3.0.8.
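
To see why the extra $fit is needed, it can help to inspect the class at each level of nesting. The sketch below is only an illustration: it uses add_formula() instead of the question's recipe to stay short, and wf, is_wrapper and is_rpart are made-up names. The internal layout of a workflow is not a documented interface, so it may differ between package versions.

```r
library(tidymodels)

df <- mtcars %>% mutate(gear = factor(gear))

# Simplified workflow (formula preprocessor instead of a recipe)
wf <- workflow() %>%
  add_formula(gear ~ .) %>%
  add_model(decision_tree() %>% set_engine("rpart") %>% set_mode("classification")) %>%
  fit(df)

class(wf)              # the workflow itself
class(wf$fit$fit)      # parsnip wrapper: the object rpart.plot() rejected
class(wf$fit$fit$fit)  # engine object: the one rpart.plot() accepts

# The wrapper is a parsnip model_fit; the rpart tree sits one level deeper
is_wrapper <- inherits(wf$fit$fit, "model_fit")
is_rpart   <- inherits(wf$fit$fit$fit, "rpart")
```

Both checks should come back TRUE, which matches the error in the question: tree_wf$fit$fit is a parsnip object wrapping the rpart tree, not the tree itself.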

Answered Sep 28 '22 18:09 by Stephen Milborrow