
ROC curve from training data in caret

Using the R package caret, how can I generate a ROC curve based on the cross-validation results of the train() function?

Say, I do the following:

```r
library(caret)
library(mlbench)  # provides the Sonar data set
data(Sonar)
ctrl <- trainControl(method="cv",
    summaryFunction=twoClassSummary,
    classProbs=TRUE)
rfFit <- train(Class ~ ., data=Sonar,
    method="rf", preProc=c("center", "scale"),
    trControl=ctrl)
```

The training function tries a range of values for the mtry parameter and computes the ROC AUC for each. I would like to see the associated ROC curve; how do I do that?

Note: if the resampling method is LOOCV, then rfFit will contain a non-null data frame in the rfFit$pred slot, which seems to be exactly what I need. However, I need it for the "cv" method (k-fold cross-validation) rather than LOO.

Also: no, the roc function that was included in earlier versions of caret is not an answer. It is a low-level function, and you can't use it unless you have the predicted probabilities for each cross-validated sample.

asked Jun 30 '15 at 12:06 by January



1 Answer

The only thing missing is the savePredictions = TRUE argument in ctrl (this also works for other resampling methods):

```r
library(caret)
library(mlbench)
data(Sonar)

ctrl <- trainControl(method="cv",
                     summaryFunction=twoClassSummary,
                     classProbs=TRUE,
                     savePredictions=TRUE)

rfFit <- train(Class ~ ., data=Sonar,
               method="rf", preProc=c("center", "scale"),
               trControl=ctrl)

library(pROC)
# Select a parameter setting
selectedIndices <- rfFit$pred$mtry == 2
# Plot:
plot.roc(rfFit$pred$obs[selectedIndices],
         rfFit$pred$M[selectedIndices])
```
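Rather than hard-coding mtry == 2, the held-out predictions for whichever mtry value train() selected can be pulled out via rfFit$bestTune. A minimal sketch, assuming caret, mlbench, randomForest and pROC are installed; the seed and the names bestMtry, bestPred and rocBest are illustrative, and passing explicit levels/direction to pROC::roc (treating "M" as the event class) is an assumption rather than something the answer specifies:

```r
# Sketch: assumes caret, mlbench, randomForest and pROC are installed.
library(caret)
library(mlbench)
library(pROC)
data(Sonar)

set.seed(42)  # arbitrary seed, only so the folds are reproducible
ctrl <- trainControl(method = "cv",
                     summaryFunction = twoClassSummary,
                     classProbs = TRUE,
                     savePredictions = TRUE)
rfFit <- train(Class ~ ., data = Sonar,
               method = "rf", preProc = c("center", "scale"),
               trControl = ctrl)

# Held-out predictions for the mtry value that train() selected
bestMtry <- rfFit$bestTune$mtry
bestPred <- rfFit$pred[rfFit$pred$mtry == bestMtry, ]

# One ROC curve over all folds: "R" = controls, "M" = cases,
# and a higher predicted probability of "M" should indicate a case
rocBest <- roc(response = bestPred$obs, predictor = bestPred$M,
               levels = c("R", "M"), direction = "<")
plot(rocBest)
auc(rocBest)
```

With k-fold CV every row of Sonar is held out exactly once per tuning value, so bestPred should contain one prediction per observation.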


One small concern, unless I am missing something: train always reports slightly different AUC values than plot.roc and pROC::auc (absolute difference < 0.005), even though twoClassSummary uses pROC::auc to compute the AUC. Edit: I assume this happens because the ROC reported by train is the average of the AUCs computed on the separate CV folds, whereas here the AUC is computed once over all resampled predictions pooled together.
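That explanation can be checked directly: compute the AUC separately on each fold (the Resample column of rfFit$pred), average those values, and compare the result with both the value train() reports and a single AUC over the pooled predictions. A sketch, assuming twoClassSummary computes one pROC AUC per held-out fold and caret averages them; the seed and the names pred2, foldAUC, averagedAUC and pooledAUC are hypothetical:

```r
# Sketch: assumes caret, mlbench, randomForest and pROC are installed.
library(caret)
library(mlbench)
library(pROC)
data(Sonar)

set.seed(42)  # arbitrary seed
ctrl <- trainControl(method = "cv", summaryFunction = twoClassSummary,
                     classProbs = TRUE, savePredictions = TRUE)
rfFit <- train(Class ~ ., data = Sonar, method = "rf",
               preProc = c("center", "scale"), trControl = ctrl)

pred2 <- rfFit$pred[rfFit$pred$mtry == 2, ]

# AUC computed separately on each held-out fold, then averaged
foldAUC <- sapply(split(pred2, pred2$Resample), function(d) {
  as.numeric(auc(roc(d$obs, d$M, levels = c("R", "M"), direction = "<")))
})
averagedAUC <- mean(foldAUC)

# AUC computed once over all held-out predictions pooled together
pooledAUC <- as.numeric(auc(roc(pred2$obs, pred2$M,
                                levels = c("R", "M"), direction = "<")))

# Value reported by train() next to the two hand-computed versions
c(train    = rfFit$results$ROC[rfFit$results$mtry == 2],
  averaged = averagedAUC,
  pooled   = pooledAUC)
```

If the assumption holds, the train and averaged entries should agree up to floating-point error, while pooled corresponds to what plot.roc and pROC::auc report on the combined predictions.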

Update: since this is getting a bit of attention, here's a solution using plotROC::geom_roc() for ggplot2:

```r
library(ggplot2)
library(plotROC)
ggplot(rfFit$pred[selectedIndices, ],
       aes(m = M, d = factor(obs, levels = c("R", "M")))) +
    geom_roc(hjust = -0.4, vjust = 1.5) +
    coord_equal()
```
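If, as assumed in the answer, the AUC reported by train() is an average over folds, it can also help to draw one curve per fold. A sketch, assuming caret, mlbench, randomForest, ggplot2 and plotROC are installed; mapping color = Resample and setting n.cuts = 0 (to hide the cut-off points) are illustrative choices, the seed and the names pred2, p and foldAUCs are hypothetical, and calc_auc() is plotROC's helper that returns the AUC of each curve drawn by geom_roc:

```r
# Sketch: assumes caret, mlbench, randomForest, ggplot2 and plotROC are installed.
library(caret)
library(mlbench)
library(ggplot2)
library(plotROC)
data(Sonar)

set.seed(42)  # arbitrary seed
ctrl <- trainControl(method = "cv", summaryFunction = twoClassSummary,
                     classProbs = TRUE, savePredictions = TRUE)
rfFit <- train(Class ~ ., data = Sonar, method = "rf",
               preProc = c("center", "scale"), trControl = ctrl)

pred2 <- rfFit$pred[rfFit$pred$mtry == 2, ]

# One ROC curve per held-out fold, coloured by the Resample column;
# n.cuts = 0 hides the cut-off points along each curve
p <- ggplot(pred2, aes(m = M, d = factor(obs, levels = c("R", "M")),
                       color = Resample)) +
  geom_roc(n.cuts = 0) +
  coord_equal()
print(p)

# AUC of each curve drawn in the plot (one row per fold)
foldAUCs <- calc_auc(p)
foldAUCs
```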


answered Sep 18 '22 at 12:09 by thie1e
