Help Understanding Cross Validation and Decision Trees

I've been reading up on Decision Trees and Cross Validation, and I understand both concepts. However, I'm having trouble understanding Cross Validation as it pertains to Decision Trees. Essentially, Cross Validation lets you alternate between training and testing when your dataset is relatively small, so that you get the most out of your data when estimating the error. A very simple algorithm goes something like this:

  1. Decide on the number of folds you want (k)
  2. Subdivide your dataset into k folds
  3. Use k-1 folds as the training set to build a tree.
  4. Use the remaining fold as the test set to estimate statistics about the error in your tree.
  5. Save your results for later
  6. Repeat steps 3-5 k times, leaving out a different fold as the test set each time.
  7. Average the errors across the k iterations to estimate the overall error
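The steps above can be sketched in Python with only the standard library. This is a minimal illustration under stated assumptions: the learner is a hypothetical depth-1 decision tree (a "stump") on a single numeric feature, standing in for a real tree builder, and the names `k_fold_indices`, `cross_validate`, `fit_stump` and `predict_stump` are made up for this sketch, not taken from any library.

```python
import random
from collections import Counter
from typing import Callable, List, Sequence, Tuple

# Hypothetical model type: (threshold, label if x <= threshold, label otherwise)
Stump = Tuple[float, int, int]


def k_fold_indices(n: int, k: int, seed: int = 0) -> List[List[int]]:
    """Steps 1-2: shuffle the row indices and deal them into k near-equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def fit_stump(xs: Sequence[float], ys: Sequence[int]) -> Stump:
    """Hypothetical learner: pick the threshold with the fewest training errors."""
    overall = Counter(ys).most_common(1)[0][0]
    best = None
    for t in sorted(set(xs)):
        left = [y for x, y in zip(xs, ys) if x <= t]  # never empty: t >= min(xs)
        right = [y for x, y in zip(xs, ys) if x > t]
        lab_l = Counter(left).most_common(1)[0][0]
        lab_r = Counter(right).most_common(1)[0][0] if right else overall
        err = sum(y != lab_l for y in left) + sum(y != lab_r for y in right)
        if best is None or err < best[0]:
            best = (err, t, lab_l, lab_r)
    _, t, lab_l, lab_r = best
    return t, lab_l, lab_r


def predict_stump(model: Stump, x: float) -> int:
    t, lab_l, lab_r = model
    return lab_l if x <= t else lab_r


def cross_validate(
    X: Sequence[float],
    y: Sequence[int],
    k: int,
    fit: Callable[[List[float], List[int]], Stump],
    predict: Callable[[Stump, float], int],
    seed: int = 0,
) -> List[float]:
    """Steps 3-6: train on k-1 folds, test on the held-out fold, keep each error rate."""
    errors = []
    for test_idx in k_fold_indices(len(X), k, seed):
        held_out = set(test_idx)
        train_idx = [i for i in range(len(X)) if i not in held_out]
        model = fit([X[i] for i in train_idx], [y[i] for i in train_idx])
        wrong = sum(predict(model, X[i]) != y[i] for i in test_idx)
        errors.append(wrong / len(test_idx))
    return errors


if __name__ == "__main__":
    X = [float(i) for i in range(20)]
    y = [0 if x < 10 else 1 for x in X]
    fold_errors = cross_validate(X, y, k=5, fit=fit_stump, predict=predict_stump)
    # Step 7: average the per-fold errors into a single estimate.
    print(sum(fold_errors) / len(fold_errors))
```

Note that each iteration builds its own stump, so the k fitted models can differ; the procedure only keeps their error rates.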

The problem I can't figure out is that at the end you'll have k decision trees that could all be slightly different, because they might not split the same way, etc. Which tree do you pick? One idea I had was to pick the one with the minimal error (although that doesn't make it optimal, just that it performed best on the fold it was given; maybe using stratification will help, but everything I've read says it only helps a little).

As I understand cross validation, the point is to compute in-node statistics that can later be used for pruning. So each node in the tree will have statistics calculated for it based on the test set given to it. What's important are these in-node stats, but if you're averaging your error, how do you merge these stats within each node across k trees, when each tree could vary in what it chooses to split on, etc.?

What's the point of calculating the overall error across all the iterations? That's not something that could be used during pruning.

Any help with this little wrinkle would be much appreciated.

asked Feb 22 '10 at 22:02 by chubbsondubs



People also ask

How do you explain cross-validation?

Cross-validation is a resampling method that uses different portions of the data to test and train a model on different iterations. It is mainly used in settings where the goal is prediction, and one wants to estimate how accurately a predictive model will perform in practice.

How do you find the best depth for a decision tree?

By resampling the data many times, splitting it into training and validation folds, fitting trees of different sizes on the training folds, and looking at the classification accuracy on the validation folds, we can find the tree depth that gives the best bias-variance trade-off.
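The depth search described above can be sketched with the standard library alone. This is an illustrative sketch, not CART: the hypothetical `fit_tree` grows a depth-limited tree on a single numeric feature and chooses splits greedily by raw misclassification count (real implementations typically use Gini impurity or entropy), and the hypothetical `select_depth` returns the depth with the lowest mean k-fold validation error.

```python
import random
from collections import Counter
from statistics import mean
from typing import Dict, List, Sequence, Tuple, Union

# A tree is either a leaf label (int) or a (threshold, left_subtree, right_subtree) tuple.
Tree = Union[int, Tuple[float, "Tree", "Tree"]]


def majority(ys: Sequence[int]) -> int:
    return Counter(ys).most_common(1)[0][0]


def fit_tree(xs: Sequence[float], ys: Sequence[int], depth: int) -> Tree:
    """Grow a tree with at most `depth` splits on any root-to-leaf path."""
    if depth == 0 or len(set(ys)) == 1:
        return majority(ys)
    best = None
    for t in sorted(set(xs))[:-1]:  # skip the max value so both sides are non-empty
        left = [y for x, y in zip(xs, ys) if x <= t]
        right = [y for x, y in zip(xs, ys) if x > t]
        ml, mr = majority(left), majority(right)
        err = sum(y != ml for y in left) + sum(y != mr for y in right)
        if best is None or err < best[0]:
            best = (err, t)
    if best is None:  # all feature values identical: nothing to split on
        return majority(ys)
    t = best[1]
    lx = [x for x in xs if x <= t]
    ly = [y for x, y in zip(xs, ys) if x <= t]
    rx = [x for x in xs if x > t]
    ry = [y for x, y in zip(xs, ys) if x > t]
    return (t, fit_tree(lx, ly, depth - 1), fit_tree(rx, ry, depth - 1))


def predict_tree(tree: Tree, x: float) -> int:
    while isinstance(tree, tuple):
        t, left, right = tree
        tree = left if x <= t else right
    return tree


def select_depth(
    X: Sequence[float], y: Sequence[int], depths: List[int], k: int = 5, seed: int = 0
) -> Tuple[int, Dict[int, float]]:
    """Return the depth with the lowest mean validation error, plus every score."""
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]  # same folds for every candidate depth
    scores: Dict[int, float] = {}
    for d in depths:
        fold_errors = []
        for test_idx in folds:
            held_out = set(test_idx)
            train = [i for i in idx if i not in held_out]
            tree = fit_tree([X[i] for i in train], [y[i] for i in train], d)
            wrong = sum(predict_tree(tree, X[i]) != y[i] for i in test_idx)
            fold_errors.append(wrong / len(test_idx))
        scores[d] = mean(fold_errors)
    best = min(depths, key=lambda d: scores[d])  # ties go to the depth listed first
    return best, scores
```

Reusing the same folds for every candidate depth keeps the comparison between depths fair; listing depths in ascending order makes ties favour the simpler tree.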


1 Answer

The problem I can't figure out is at the end you'll have k Decision trees that could all be slightly different because they might not split the same way, etc. Which tree do you pick?

The purpose of cross validation is not to help select a particular instance of the classifier (or decision tree, or whatever automatic learning application) but rather to qualify the model, i.e. to provide metrics such as the average error rate and the deviation around that average, which are useful in assessing the level of precision one can expect from the application. One of the things cross validation can help assess is whether the training data is big enough.
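As an illustration of the metrics mentioned here, the per-fold error rates can be summarized by their mean and spread. A minimal sketch using Python's `statistics` module; `summarize_cv` is a hypothetical helper name, and the fold errors in the usage line are made-up numbers, not results from any real run.

```python
from statistics import mean, stdev
from typing import Sequence, Tuple


def summarize_cv(fold_errors: Sequence[float]) -> Tuple[float, float]:
    """Return (mean error, sample standard deviation) across the k folds.

    The mean is the cross-validated error estimate; the standard deviation shows
    how much the error moves depending on which fold is held out.
    """
    if len(fold_errors) < 2:
        raise ValueError("need at least two folds to measure spread")
    return mean(fold_errors), stdev(fold_errors)


if __name__ == "__main__":
    avg, spread = summarize_cv([0.10, 0.15, 0.05, 0.20, 0.10])  # illustrative values only
    print(f"estimated error {avg:.3f} +/- {spread:.3f}")
```

A large spread relative to the mean can be one hint that each fold, and therefore the training set, is small.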

As for selecting a particular tree, you should instead run one more training pass on 100% of the available training data, as this typically produces a better tree. (The downside of the cross validation approach is that we need to divide the [typically small] amount of training data into "folds", and, as you hint in the question, this can lead to trees that are either overfit or underfit for particular data instances.)
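This workflow, where cross validation produces only an error estimate and the tree you actually keep is trained once more on all of the data, can be sketched as follows. It is a standard-library-only sketch: `evaluate_then_refit` is a hypothetical helper, and `fit`/`predict` stand for whatever tree learner you use; the majority-class learner in the usage block is only a placeholder to make it runnable.

```python
import random
from collections import Counter
from typing import Any, Callable, List, Sequence, Tuple


def evaluate_then_refit(
    X: Sequence[Any],
    y: Sequence[Any],
    k: int,
    fit: Callable[[List[Any], List[Any]], Any],
    predict: Callable[[Any, Any], Any],
    seed: int = 0,
) -> Tuple[Any, List[float]]:
    """Estimate the error with k-fold CV, then fit the model you keep on 100% of the data.

    The k per-fold models exist only to produce the error estimate; none of them
    is returned or merged.
    """
    idx = list(range(len(X)))
    random.Random(seed).shuffle(idx)
    folds = [idx[i::k] for i in range(k)]
    fold_errors = []
    for test_idx in folds:
        held_out = set(test_idx)
        train = [i for i in idx if i not in held_out]
        fold_model = fit([X[i] for i in train], [y[i] for i in train])
        wrong = sum(predict(fold_model, X[i]) != y[i] for i in test_idx)
        fold_errors.append(wrong / len(test_idx))
    final_model = fit(list(X), list(y))  # trained on all available data
    return final_model, fold_errors


if __name__ == "__main__":
    # Placeholder learner (always predicts the majority class); swap in a real tree.
    def fit_majority(xs: List[Any], ys: List[Any]) -> Any:
        return Counter(ys).most_common(1)[0][0]

    def predict_majority(model: Any, x: Any) -> Any:
        return model

    model, errs = evaluate_then_refit(
        list(range(10)), [0] * 7 + [1] * 3, 5, fit_majority, predict_majority
    )
    print(model, sum(errs) / len(errs))
```

The returned fold errors describe how well this kind of tree generalizes; the returned model is the one you would actually deploy.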

In the case of decision trees, I'm not sure what your reference to statistics gathered in each node and used to prune the tree pertains to. Maybe a particular use of cross-validation-related techniques?...

answered Sep 21 '22 at 12:09 by mjv
