 

Train test split in `r`'s `caret` package

Tags:

r

r-caret

I'm getting familiar with R's caret package, but, coming from other programming languages, it has thoroughly confused me.

What I want to do now is a fairly simple machine learning workflow, which is:

  1. Take a dataset, in my case the iris dataset
  2. Split it into a training and test set (an 80-20 split)
  3. For every k from 1 to 20, train the k-nearest neighbor classifier on the training set
  4. Test it on the test set

I understand how to do the first part, since iris is already loaded. Then, the second part is done by calling

```r
library(caret)
a <- createDataPartition(iris$Species, list=FALSE)
training <- iris[a,]
test <- iris[-a,]
```

Now, I also know that I can train the model by calling

```r
library(caret)
knnFit <- train(Species~., data=training, method="knn")
```

However, this will result in caret already performing some optimisation on the parameter k. Of course, I can limit which values of k the method should try, with something like

```r
knnFit <- train(Species~., data=training, method="knn", tuneGrid=data.frame(k=1:20))
```

which works just fine, but it still doesn't do exactly what I want it to do. For each k, this code will now:

  1. Take bootstrap samples of the training set (25 of them by default).
  2. Assess the performance of the k-NN method on the rows left out of each sample.
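For reference, a minimal sketch of where that resampling comes from: when `trControl` is not supplied, `train` falls back to `trainControl()` with its defaults. Assuming a current caret release, those defaults can be inspected directly:

```r
library(caret)

# train() uses trainControl() with default arguments when trControl is omitted
ctrl <- trainControl()

ctrl$method  # "boot": bootstrap resampling of the training set
ctrl$number  # 25 bootstrap resamples
```

Each bootstrap sample is drawn with replacement from the training rows; the model is fit on it and scored on the rows that were not drawn, so the evaluation rows differ from one resample to the next.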

What I want it to do:

  1. For each k, train the model on the same train set which I constructed earlier
  2. Assess the performance **on the same test set** which I constructed earlier.

So I would need something like

```r
knnFit <- train(Species~., training_data=training, test_data=test, method="knn", tuneGrid=data.frame(k=1:20))
```

but this of course does not work.

I understand I should do something with the `trControl` argument (built with `trainControl()`), but its possible resampling methods seem to be:

"boot", "boot632", "cv", "repeatedcv", "LOOCV", "LGOCV", "none"

and none of these seems to do what I want.

5xum asked Mar 01 '16 08:03




2 Answers

If I understand the question correctly, this can all be done within caret using LGOCV (leave-group-out cross-validation, i.e. repeated train/test splits). Set the training fraction to `p = 0.8` and the number of train/test splits to `number = 1` if you really want just one model fit per k, tested on a single test set. Setting `number > 1` assesses model performance on `number` different train/test splits and averages the results.

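```r
data(iris)
library(caret)
set.seed(123)
# LGOCV with p = 0.8 and number = 1: a single 80/20 train/test split,
# reused for every k in the tuning grid
mod <- train(Species ~ ., data = iris, method = "knn", 
             tuneGrid = expand.grid(k=1:20),
             trControl = trainControl(method = "LGOCV", p = 0.8, number = 1,
                                      savePredictions = TRUE))
```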

If `savePredictions = TRUE`, all predictions made by the different models on the test set are stored in `mod$pred`. Note the `rowIndex` column: these are the rows that were sampled into the test set. They are the same for every value of k, so the same training/test split is used every time.

```
> head(mod$pred)
    pred    obs rowIndex k  Resample
1 setosa setosa        5 1 Resample1
2 setosa setosa        6 1 Resample1
3 setosa setosa       10 1 Resample1
4 setosa setosa       12 1 Resample1
5 setosa setosa       16 1 Resample1
6 setosa setosa       17 1 Resample1
> tail(mod$pred)
         pred       obs rowIndex  k  Resample
595 virginica virginica      130 20 Resample1
596 virginica virginica      131 20 Resample1
597 virginica virginica      135 20 Resample1
598 virginica virginica      137 20 Resample1
599 virginica virginica      145 20 Resample1
600 virginica virginica      148 20 Resample1
```
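A small sketch, assuming the LGOCV fit from this answer (repeated here so it runs on its own), that checks the test rows agree across all k and computes the test-set accuracy for each k straight from `mod$pred`. The names `rows_by_k`, `same_rows` and `acc_by_k` are just illustrative:

```r
library(caret)

data(iris)
set.seed(123)
mod <- train(Species ~ ., data = iris, method = "knn",
             tuneGrid = expand.grid(k = 1:20),
             trControl = trainControl(method = "LGOCV", p = 0.8, number = 1,
                                      savePredictions = TRUE))

# Test-set row indices, grouped by k
rows_by_k <- split(mod$pred$rowIndex, mod$pred$k)

# TRUE if every k was evaluated on exactly the same test rows
same_rows <- all(vapply(rows_by_k,
                        function(r) setequal(r, rows_by_k[[1]]),
                        logical(1)))

# Share of correct test-set predictions for each k
acc_by_k <- tapply(mod$pred$pred == mod$pred$obs, mod$pred$k, mean)
```

With `number = 1` there is only one resample to average over, so these per-k values should match the `Accuracy` column of `mod$results`.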

There's no need to construct train/test sets manually outside of caret unless some kind of nested validation procedure is desired. You can also plot the validation curve for the different values of k with `plot(mod)`.
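If you would rather reuse a split you built yourself, as in the question, `trainControl()` also accepts explicit row indices: `index` lists the rows each resample is fit on and `indexOut` the rows it is evaluated on. A sketch under that assumption, using a single resample named `Split1` (an arbitrary, hypothetical label). Row numbers refer to the `data` passed to `train`, so the full `iris` is passed:

```r
library(caret)

data(iris)
set.seed(123)

# Your own stratified 80/20 split, as plain row numbers
a <- as.vector(createDataPartition(iris$Species, p = 0.8, list = FALSE))
test_rows <- setdiff(seq_len(nrow(iris)), a)

# One "resample": fit on rows a, evaluate on test_rows, for every k
ctrl <- trainControl(method = "cv",
                     index = list(Split1 = a),
                     indexOut = list(Split1 = test_rows),
                     savePredictions = TRUE)

mod_split <- train(Species ~ ., data = iris, method = "knn",
                   tuneGrid = expand.grid(k = 1:20),
                   trControl = ctrl)

# Test-set accuracy for each k
mod_split$results[, c("k", "Accuracy")]
```

As with LGOCV, `train` still refits `finalModel` on all rows of `data` with the best k, so the final model has seen the test rows; the per-k test accuracies are the ones in `mod_split$results`.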

thie1e answered Sep 30 '22 02:09


Please read through the caret website to see how everything works, or read the book "Applied Predictive Modeling" by Max Kuhn and Kjell Johnson for more detail on how caret works.

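Roughly speaking, `trainControl` collects the settings that control how `train` resamples and evaluates models: the resampling method (bootstrap, cross-validation, etc.), the summary function that computes the performance metrics (e.g. ROC via `twoClassSummary`, or RMSE for regression), class-imbalance sampling, preprocessing options, and so on.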
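To illustrate the division of labour, a short sketch (5-fold CV and the `Kappa` metric are arbitrary choices for the example): the resampling scheme goes into `trainControl()`, while the grid and the metric used to choose k are arguments of `train()`:

```r
library(caret)

data(iris)
set.seed(42)

# Resampling scheme lives in trainControl()
ctrl <- trainControl(method = "cv", number = 5)

# The tuning grid and the metric used to pick k are arguments of train()
fit <- train(Species ~ ., data = iris, method = "knn",
             tuneGrid = data.frame(k = 1:20),
             metric = "Kappa",
             trControl = ctrl)

fit$bestTune
```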

In `train` itself you set things like the tuning grid (`tuneGrid`) for the grid search. I extended your code example so it works. Make sure to check how `createDataPartition` works, because by default (`p = 0.5`) it splits the data in half.
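A quick check of that default, assuming the iris data (50 rows per species): `createDataPartition` samples within each class, so the default keeps 25 rows per species for training and `p = 0.8` keeps 40:

```r
library(caret)

data(iris)
set.seed(1)

# Default p = 0.5: half of each species goes into the training rows
a_half <- createDataPartition(iris$Species, list = FALSE)

# p = 0.8: the 80/20 split the question asks for
a_80 <- createDataPartition(iris$Species, p = 0.8, list = FALSE)

nrow(a_half)  # 75 training rows (25 per species)
nrow(a_80)    # 120 training rows (40 per species)
```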

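```r
library(caret)

# p = 0.8 gives the 80/20 split; the default p = 0.5 would split in half
a <- createDataPartition(iris$Species, p = 0.8, list=FALSE)
training <- iris[a,]
test <- iris[-a,]

# Tune k over 1:20 (train still resamples the training set internally)
knnFit <- train(Species ~ ., 
                data = training, 
                method="knn",  
                tuneGrid=data.frame(k=1:20))

# Predict the test set with the final model (fit with the best k)
knn_pred <- predict(knnFit, newdata = test)
```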

EDIT based on comment:

What you want is not possible with one `train` object. `train` uses `tuneGrid` to find the best k and then fits `finalModel` with that k. This `finalModel` is what is used for making predictions.
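A sketch of what that looks like on the fitted object, assuming an 80/20 `training` set like the one above (rebuilt here): `results` holds the resampled performance of every k, `bestTune` the chosen k, and `finalModel` the single model that `predict()` uses. The accuracies in `results` come from bootstrap resamples of `training`, not from `test`:

```r
library(caret)

data(iris)
set.seed(123)
a <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
training <- iris[a, ]

knnFit <- train(Species ~ ., data = training, method = "knn",
                tuneGrid = data.frame(k = 1:20))

knnFit$results[, c("k", "Accuracy")]  # resampled accuracy for every k
knnFit$bestTune                       # the single k that was selected
knnFit$finalModel$k                   # k of the knn3 model used by predict()
```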

If you want an overview of all k's, you might not want to use caret's `train` function but write a function yourself, maybe something like below. Note that `knn3` is the k-NN model implementation from caret.

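```r
library(caret)

k <- 20
knn_fit_list <- list()
knn_pred_list <- list()

# Fit one knn3 model per k on the same training set
# and predict the same test set with each of them
for (i in 1:k) {
  knn_fit_list[[i]] <- knn3(Species ~ ., 
                            data = training, 
                            k = i)
  knn_pred_list[[i]] <- predict(knn_fit_list[[i]], newdata = test, type = "class")
}
```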

`knn_fit_list` will contain the fitted model for each value of k from 1 to 20, and `knn_pred_list` the corresponding test-set predictions.
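To finish step 4 of the question, the test-set accuracy for each k can then be computed from those predictions. A minimal sketch, assuming the same kind of 80/20 split as above and using `lapply` in place of the `for` loop (`test_acc` is an illustrative name):

```r
library(caret)

data(iris)
set.seed(123)
a <- createDataPartition(iris$Species, p = 0.8, list = FALSE)
training <- iris[a, ]
test <- iris[-a, ]

# One knn3 fit per k on the same training set, each predicting the same test set
knn_pred_list <- lapply(1:20, function(i) {
  fit <- knn3(Species ~ ., data = training, k = i)
  predict(fit, newdata = test, type = "class")
})

# Test-set accuracy for each k
test_acc <- vapply(knn_pred_list,
                   function(p) mean(p == test$Species),
                   numeric(1))
data.frame(k = 1:20, accuracy = test_acc)
```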

phiver answered Sep 30 '22 02:09