How can I set the sub-sample size in RandomForestClassifier in Scikit-Learn, especially for imbalanced data?

Currently, I am using RandomForestClassifier in sklearn on my imbalanced data, and I am not entirely clear on how RF works in sklearn. My concerns are as follows:

  1. According to the documentation, there seems to be no way to set the sub-sample size (i.e. smaller than the original data size) for each tree learner. But in the random forest algorithm, we need to draw both a subset of samples and a subset of features for each tree. Can we achieve that via sklearn? If yes, how?

Following is the description of RandomForestClassifier in sklearn:

"A random forest is a meta estimator that fits a number of decision tree classifiers on various sub-samples of the dataset and use averaging to improve the predictive accuracy and control over-fitting. The sub-sample size is always the same as the original input sample size but the samples are drawn with replacement if bootstrap=True (default)."

I found a similar question here before, but it has not received many answers:

How can SciKit-Learn Random Forest sub sample size may be equal to original training data size?

  2. For imbalanced data, if we can do sub-sample pick-up via sklearn (i.e. solve question #1 above), can we do a balanced random forest? That is, each tree learner would pick up a subset from the less-populated class, plus the same number of samples from the more-populated class, to make up a training set with an equal distribution of the two classes, repeating the process for each of the trees.
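(For reference, the balanced bootstrap described in point 2 can be approximated by hand. The sketch below is illustrative, not a library feature: it fits one `DecisionTreeClassifier` per balanced bootstrap and combines them by majority vote; the function names are made up for this example.)

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

def balanced_bagging(X, y, n_trees=25, random_state=0):
    """Fit one tree per balanced bootstrap: draw n_min samples with
    replacement from every class, where n_min is the minority-class size."""
    rng = np.random.RandomState(random_state)
    classes, counts = np.unique(y, return_counts=True)
    n_min = counts.min()
    trees = []
    for _ in range(n_trees):
        # equal-sized bootstrap from each class -> balanced training set
        idx = np.concatenate([
            rng.choice(np.where(y == c)[0], size=n_min, replace=True)
            for c in classes
        ])
        trees.append(DecisionTreeClassifier(
            max_features="sqrt", random_state=rng.randint(1 << 30)
        ).fit(X[idx], y[idx]))
    return trees

def predict_vote(trees, X):
    # majority vote across the per-tree predictions
    preds = np.stack([t.predict(X) for t in trees])
    return np.array([np.bincount(col.astype(int)).argmax() for col in preds.T])
```

(The `imbalanced-learn` package ships a ready-made `BalancedRandomForestClassifier` that does essentially this.)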

Thank you! Cheng

asked Jul 06 '17 by Cheng Fang
People also ask

How do you deal with imbalanced data in random forest?

The solution is to use stratified sampling, which splits the data randomly while keeping the same imbalanced class distribution in each subset. The modified version of K-Fold, i.e. stratified K-Fold cross-validation, requires each split to match the class distribution of the complete training dataset.
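A minimal illustration of stratified splitting with scikit-learn's `StratifiedKFold` (the 90/10 dataset here is synthetic):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# 90/10 imbalanced labels
y = np.array([0] * 90 + [1] * 10)
X = np.arange(100).reshape(-1, 1)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
for train_idx, test_idx in skf.split(X, y):
    # each test fold preserves the 90/10 class ratio
    print(np.bincount(y[test_idx]))  # → [18  2] for every fold
```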

Does random forest work with imbalanced data?

Random forest is very effective on a wide range of problems, but, like bagging, the standard algorithm does not perform well on imbalanced classification problems.

How many samples do you need for random forest?

It helps if your data has very low noise. From 40-50 samples results start to improve, and around 500 samples is good.

What is the best n_estimators in random forest?

We may use RandomizedSearchCV for choosing n_estimators in a random forest as an alternative to GridSearchCV. This will also return the best parameters for the random forest model.
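A short sketch of tuning `n_estimators` with `RandomizedSearchCV` (the candidate values and the synthetic dataset are arbitrary choices for illustration):

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV

# small synthetic imbalanced dataset
X, y = make_classification(n_samples=200, weights=[0.9, 0.1], random_state=0)

search = RandomizedSearchCV(
    RandomForestClassifier(random_state=0),
    param_distributions={"n_estimators": [10, 50, 100]},
    n_iter=3,   # samples 3 candidates from the distribution
    cv=3,
    random_state=0,
)
search.fit(X, y)
print(search.best_params_)
```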


1 Answer

There is no obvious way, but you can patch the sampling method in sklearn.ensemble.forest.

(Updated on 2021-04-23, as I found sklearn refactored the code.)

By using set_rf_samples(n), you can force each tree to sub-sample n rows, and call reset_rf_samples() to go back to sampling the whole dataset.

For versions < 0.22.0:

from sklearn.ensemble import forest

def set_rf_samples(n):
    """ Changes Scikit-learn's random forests to give each tree a random sample of
    n random rows.
    """
    # Monkey-patch the module-private helper that draws bootstrap indices,
    # so every tree receives n indices instead of n_samples.
    forest._generate_sample_indices = (lambda rs, n_samples:
        forest.check_random_state(rs).randint(0, n_samples, n))

def reset_rf_samples():
    """ Undoes the changes produced by set_rf_samples.
    """
    forest._generate_sample_indices = (lambda rs, n_samples:
        forest.check_random_state(rs).randint(0, n_samples, n_samples))
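The patch works because the pre-0.22 forest code draws bootstrap indices through `_generate_sample_indices(random_state, n_samples)`; the replacement simply returns `n` indices instead of `n_samples`. The index-drawing call can be seen in isolation (no monkey-patching here; `check_random_state` lives in `sklearn.utils` in current releases):

```python
from sklearn.utils import check_random_state

n_samples, n = 1000, 100  # dataset size vs. desired sub-sample size

# original behaviour: one bootstrap index per input row
full = check_random_state(42).randint(0, n_samples, n_samples)
# patched behaviour: only n indices, so each tree sees at most n rows
sub = check_random_state(42).randint(0, n_samples, n)

print(len(full), len(sub))  # → 1000 100
```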
  

For versions >= 0.22.0:

There is now a max_samples parameter available: https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestRegressor.html

max_samples: int or float, default=None

   If bootstrap is True, the number of samples to draw from X to train each base estimator.

   If None (default), then draw X.shape[0] samples.

   If int, then draw max_samples samples.

   If float, then draw max_samples * X.shape[0] samples. Thus, max_samples should be in the interval (0, 1).
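A quick sketch using max_samples (sklearn >= 0.22). For the balanced-forest part of the question, `class_weight="balanced_subsample"` is also worth noting: it reweights classes within each bootstrap rather than resampling them, which is a different but related remedy:

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# synthetic 90/10 imbalanced dataset for illustration
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

clf = RandomForestClassifier(
    n_estimators=100,
    max_samples=0.3,                     # each tree bootstraps 30% of the rows
    class_weight="balanced_subsample",   # reweight classes per bootstrap
    bootstrap=True,                      # required for max_samples to apply
    random_state=0,
).fit(X, y)

print(clf.score(X, y))
```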

reference: fast.ai Machine Learning Course

answered Nov 09 '22 by mediumnok