
Different ways of implementing cross-validation for an SVM model in MATLAB

Suppose that we have this code in MATLAB R2015b:

SVMModel = fitcsvm(INPUT, output,'KernelFunction','RBF','BoxConstraint',1);
CVSVMModel = crossval(SVMModel);
z = kfoldLoss(CVSVMModel)
  • In the first line, fitcsvm trains the model on the whole data set. What is the purpose of setting 'CrossVal' to 'on' in fitcsvm (by default this option gives 10-fold cross-validation)? Do crossval and kfoldLoss use the same method as above? If so, why does the MATLAB documentation mention only this method and not the 'CrossVal' option for cross-validation? And if the two procedures are the same, how can we get the error rate using the first one?

  • When we want to predict (this is a prediction model), do we need to use the model trained on the whole data set (here, the SVMModel object)? In that case crossval and kfoldLoss are used only for calculating the error, and the 10 models trained during this validation method are not used for prediction. Is this true? Is using the whole data set also valid for neural network models?

Asked Sep 25 '22 by Eghbal

1 Answer

Regarding the first question: setting "CrossVal" to "on" and passing the trained model to the crossval() function both achieve the same thing. You can use one or the other; it's up to you.
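
As an illustration, here is a minimal sketch of the two equivalent routes, reusing the INPUT and output variables from the question:

% Route 1: ask fitcsvm to cross-validate directly via the 'CrossVal' flag
CVModel1 = fitcsvm(INPUT, output, 'KernelFunction', 'RBF', ...
    'BoxConstraint', 1, 'CrossVal', 'on');   % 10-fold by default

% Route 2: train on the whole data set first, then cross-validate the result
SVMModel = fitcsvm(INPUT, output, 'KernelFunction', 'RBF', 'BoxConstraint', 1);
CVModel2 = crossval(SVMModel);               % also 10-fold by default

% Both calls return a ClassificationPartitionedModel. Note that the folds
% are drawn at random, so the two objects will generally hold different
% partitions of the data.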

kfoldLoss() is a function in its own right; it is not part of the "CrossVal" flag. It takes a cross-validated model as input, no matter whether you cross-validated that model with the "CrossVal" flag in fitcsvm() or with the crossval() function itself. Using this function is mandatory if you want to evaluate the error rate.
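
For instance, continuing the sketch above, either cross-validated object yields the error estimate in the same way:

z1 = kfoldLoss(CVModel1)   % classification error from the 'CrossVal' route
z2 = kfoldLoss(CVModel2)   % classification error from the crossval() route
% By default kfoldLoss returns the misclassification rate averaged over the folds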

Regarding the second question now, the short answer is yes: you have to use the trained Support Vector Machine model as returned by fitcsvm(). The cross-validation procedure aims at validating your model, so you get an idea of its performance (10-fold cross-validation is just one of the many methods available), but it does not perform any prediction.

For prediction you have to use the predict() function. I reckon you have a training set and a test set (or validation set), each with its respective labels. With the training set you train the SVM model, whereas you use the validation set to perform the prediction phase. The main output of predict() is the vector of labels the model has predicted; you can match these predicted labels against the true labels of your validation set to compute the error rate in validation.
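
As a sketch of that prediction phase, assuming hypothetical held-out arrays XTest and YTest (not part of the original question) with the same feature layout as INPUT:

% Train on the full training set, exactly as in the question
SVMModel = fitcsvm(INPUT, output, 'KernelFunction', 'RBF', 'BoxConstraint', 1);

% Predict labels for the held-out validation set
predictedLabels = predict(SVMModel, XTest);

% Validation error rate: fraction of predictions that disagree with the
% true labels (assumes numeric/logical labels; use strcmp for cell arrays)
validationError = mean(predictedLabels ~= YTest);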

I suggest you avoid the "CrossVal" flag; this way you keep the situation under control, since you'll have:

  1. the trained model, output of fitcsvm()
  2. the cross-validated model, output of crossval(), whose performance you can evaluate with kfoldLoss()
  3. the predicted labels, using predict() with the trained model from step #1
Answered Oct 30 '22 by AlessioX