Assume you’re dealing with an imbalanced dataset. I know I can do things like upsampling, downsampling, and synthetic sampling to build out my train and test split. My question is: if I’m using a random forest classifier, are there any implementations in R or Python that force each of the randomly generated trees to be trained on a sample with balanced classes?

Afflatus

3 Answers

Apparently what I was describing is called "Balanced Random Forests."

There is a separate stack page that mentions a corresponding R package: Implementing Balanced Random Forest (BRF) in R using RandomForests
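For illustration, here is a minimal sketch of the sampling step behind the idea, under the assumption that each tree is grown on a bootstrap sample in which every class contributes as many rows as the smallest class has. The function name `balanced_bootstrap_indices` and the toy labels are hypothetical, and the sketch uses only the Python standard library rather than any particular package's implementation.

```python
import random
from collections import defaultdict

def balanced_bootstrap_indices(labels, rng):
    """Return row indices for one tree: n draws with replacement per class,
    where n is the size of the smallest class."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    n_per_class = min(len(rows) for rows in by_class.values())
    sample = []
    for rows in by_class.values():
        # Sample with replacement so the minority class is bootstrapped too.
        sample.extend(rng.choices(rows, k=n_per_class))
    return sample

labels = ["neg"] * 95 + ["pos"] * 5  # hypothetical 95:5 imbalance
rng = random.Random(0)
idx = balanced_bootstrap_indices(labels, rng)
counts = {c: sum(labels[i] == c for i in idx) for c in set(labels)}
```

A forest built this way would call the function once per tree, so every tree sees a different balanced sample.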

Afflatus

sklearn.ensemble.RandomForestClassifier accepts an argument class_weight that allows you to control how the samples are weighted, either globally or for each tree. In particular,

The “balanced_subsample” mode is the same as “balanced” except that weights are computed based on the bootstrap sample for every tree grown.

which seems to be exactly what you're asking about.
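A minimal sketch of how this could look in practice. The synthetic dataset, its roughly 9:1 imbalance, and the hyperparameters are assumptions made for illustration only.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Synthetic data with roughly a 9:1 class imbalance (illustrative only).
X, y = make_classification(
    n_samples=2000, n_features=10, weights=[0.9, 0.1], random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, stratify=y, random_state=0
)

# Class weights are recomputed from each tree's own bootstrap sample.
clf = RandomForestClassifier(
    n_estimators=100, class_weight="balanced_subsample", random_state=0
)
clf.fit(X_train, y_train)
pred = clf.predict(X_test)
```

Note that this reweights the samples each tree sees rather than changing which rows are drawn, so it is not quite the same thing as resampling to equal class counts.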

Sycorax

Maybe you can try to make each tree's subsample balanced:

import numpy as np
# Private module, so this may break between releases; older scikit-learn
# versions exposed it as sklearn.ensemble.forest instead.
from sklearn.ensemble import _forest as forest
from sklearn.utils import check_random_state

def set_rf_balanced_subsampling(y_tt_labels):
    """Monkey-patch scikit-learn's random forests so that each tree is grown on a
    balanced bootstrap sample: n rows per class, where n is the minority class size.
    """
    y = np.asarray(y_tt_labels)
    classes, counts = np.unique(y, return_counts=True)
    each_tree_class_samples = counts.min()

    # Positional row indices per class (the forest indexes rows by position).
    indices = {label: np.flatnonzero(y == label) for label in classes}

    def balanced_sampling(random_state, n_samples, *args):
        # *args absorbs n_samples_bootstrap, which newer versions pass as well.
        rng = check_random_state(random_state)  # one generator for all classes
        return np.concatenate([rng.choice(indices[label], each_tree_class_samples, replace=True)
                               for label in classes])

    forest._generate_sample_indices = balanced_sampling

Also, I recommend checking out the imbalanced-learn (imblearn) library and combining its Pipeline with RandomOverSampler.