
caret rpart decision tree plotting result

I am training a decision tree model based on the heart disease data from Kaggle.

Since I am also building other models using 10-fold CV, I am trying to use the caret package with the rpart method to build the tree. However, the plot looks odd, because "thalium" should be a factor. Why does it show "thaliumnormal < 0.5"? Does this mean that if "thalium" == "normal", take the left route ("yes"), otherwise the right route ("no")?

Many thanks!

caret rpart decision tree plot using fancyRpartPlot

Edit: I apologize for not providing enough background info, which seems to have caused some confusion. "thalium" is a variable that represents a technique used to detect coronary stenosis (aka narrowing). It's a factor with three levels (normal, fixed defect, reversible defect).

data structure

In addition, I would like to make the graph more readable, e.g. instead of "thaliumnormal < 0.5" it should say something like "thalium = normal". I can achieve this by using rpart directly (see below).

rpart decision tree plot

However, you have probably noticed that the tree is different, even though I used the cp value recommended by caret's rpart 10-fold CV (see the code below).

code recommended cp, used for rpart tree using fancyRpartplot

I understand that these two packages may give somewhat different results. Ideally, I would use caret with method rpart to build the tree so that it aligns with the other models built in caret. Does anyone know how I could make the plot labels for the tree built with caret rpart easier to understand?

Rui Tongyu asked Jan 09 '20 04:01



1 Answer

It would help to see some data, e.g. dput(head(data)) to show what your data really looks like, or str(data) to show the variable types and factor levels.

But most likely (without having seen the data), the variable is thalium, one of its levels is normal, and the tree has picked that single LEVEL of thalium and is testing whether an observation has that level or not.

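The tree here sees the categorical variable as one 0/1 dummy per level (caret's formula interface, train(y ~ ., ...), expands factors into dummies before rpart is called) and makes its decision based on the dummy being >= 0.5 or < 0.5; 0 is always below the cutoff and 1 is always above it. So "thaliumnormal < 0.5" is true when thalium is not normal. Passing predictors and outcome separately, train(x = ..., y = ...), leaves factors intact, so rpart can label the split by level.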
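To see where a name like "thaliumnormal" comes from, here is a minimal sketch of the dummy expansion using base R's model.matrix(). The toy data frame is made up for illustration; only the variable name and level names are taken from the question.

```r
# Hypothetical toy data: only the variable and level names come from the question
toy <- data.frame(
  thalium = factor(
    c("normal", "fixed defect", "reversible defect", "normal"),
    levels = c("fixed defect", "normal", "reversible defect")
  )
)

# Expand the factor into 0/1 indicator columns; the first level is the
# reference level and gets no column of its own
dummies <- model.matrix(~ thalium, data = toy)
print(colnames(dummies))

# Rows for which "thaliumnormal < 0.5" holds are the ones that are NOT normal
not_normal <- dummies[, "thaliumnormal"] < 0.5
print(toy$thalium[not_normal])
```

The column name is just the variable name pasted onto the level name, which is why the plot label reads "thaliumnormal" rather than "thalium = normal".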

By design, most tree algorithms choose, for each variable (including a 0/1 dummy), the cutoff that creates the most purity, i.e. the one that best separates the classes into the two child nodes. The cutoff is placed midway between the two adjacent values that give the greatest separation between groups.

With a binary variable, that split is at 0.5 because it is midway between the only two values a level's dummy can take, 0 and 1.
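A quick way to check the 0.5 cutoff is to fit rpart on a made-up 0/1 predictor and read the split point off the fitted object. This is only a sketch: the data are invented so that the dummy separates the classes perfectly, and the names d, x and y are placeholders.

```r
library(rpart)

# Hypothetical data: x is a 0/1 dummy that separates the two classes perfectly
d <- data.frame(x = rep(c(0, 1), each = 20))
d$y <- factor(ifelse(d$x == 1, "yes", "no"))

fit <- rpart(y ~ x, data = d, method = "class")

# For a numeric predictor, the "index" column of fit$splits holds the cutoff
cutoff <- unname(fit$splits[1, "index"])
print(cutoff)
```

The cutoff lands halfway between 0 and 1, which is the "< 0.5" that shows up in the plot labels.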

sconfluentus answered Sep 18 '22 15:09