
Scikit-learn seems to use probabilistic prediction instead of majority vote as the model aggregation technique, without any explanation as to why (section 1.9.2.1. Random Forests).

Is there a clear explanation why? Furthermore, is there a good paper or review article on the various model aggregation techniques that can be used for Random Forest bagging?

Thanks!

user1745038

1 Answer


Such questions are always best answered by looking at the code, if you're fluent in Python.

RandomForestClassifier.predict, at least in the current version 0.16.1, predicts the class with the highest probability estimate, as given by predict_proba (this line).

The documentation for predict_proba says:

The predicted class probabilities of an input sample is computed as the mean predicted class probabilities of the trees in the forest. The class probability of a single tree is the fraction of samples of the same class in a leaf.
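
To see that concretely, here is a quick sanity check; the make_classification toy data is just an illustration of mine, not anything from the question:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Small made-up dataset, only to exercise the forest.
X, y = make_classification(n_samples=200, n_features=8, random_state=0)
clf = RandomForestClassifier(n_estimators=25, random_state=0).fit(X, y)

# The forest's predict_proba is the mean of the per-tree probabilities...
manual = np.mean([t.predict_proba(X) for t in clf.estimators_], axis=0)
print(np.allclose(manual, clf.predict_proba(X)))  # True

# ...and predict is the argmax of those averaged probabilities.
print(np.array_equal(clf.classes_.take(clf.predict_proba(X).argmax(axis=1)),
                     clf.predict(X)))  # True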

The difference from the original method is probably just so that predict gives predictions consistent with predict_proba. The result is sometimes called "soft voting", rather than the "hard" majority vote used in the original Breiman paper. In a quick search I couldn't find an appropriate comparison of the performance of the two methods, but they both seem fairly reasonable in this situation.
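
As a toy illustration of how the two schemes can disagree (the per-tree probabilities below are made up, not taken from a real forest):

import numpy as np

# Hypothetical class-probability estimates from three trees for one sample.
tree_probs = np.array([[0.6, 0.4],
                       [0.6, 0.4],
                       [0.1, 0.9]])

# Soft vote: average first, then argmax -> mean is [0.43, 0.57], so class 1.
soft = tree_probs.mean(axis=0).argmax()

# Hard vote: argmax per tree, then majority -> two of three trees pick class 0.
hard = np.bincount(tree_probs.argmax(axis=1)).argmax()

print(soft, hard)  # 1 0 -- the two schemes disagree for this sample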

The predict documentation is at best quite misleading; I've submitted a pull request to fix it.

If you want to do majority vote prediction instead, here's a function to do it. Call it like predict_majvote(clf, X) rather than clf.predict(X). (Based on predict_proba; only lightly tested, but I think it should work.)

import numpy as np

from scipy.stats import mode
from sklearn.ensemble.forest import _partition_estimators, _parallel_helper
from sklearn.tree._tree import DTYPE
from sklearn.externals.joblib import Parallel, delayed
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

def predict_majvote(forest, X):
    """Predict class for X.

    Uses majority voting, rather than the soft voting scheme
    used by RandomForestClassifier.predict.

    Parameters
    ----------
    X : array-like or sparse matrix of shape = [n_samples, n_features]
        The input samples. Internally, it will be converted to
        ``dtype=np.float32`` and if a sparse matrix is provided
        to a sparse ``csr_matrix``.

    Returns
    -------
    y : array of shape = [n_samples] or [n_samples, n_outputs]
        The predicted classes.
    """
    check_is_fitted(forest, 'n_outputs_')

    # Check data
    X = check_array(X, dtype=DTYPE, accept_sparse="csr")

    # Assign chunk of trees to jobs
    n_jobs, n_trees, starts = _partition_estimators(forest.n_estimators,
                                                    forest.n_jobs)

    # Parallel loop
    all_preds = Parallel(n_jobs=n_jobs, verbose=forest.verbose,
                         backend="threading")(
        delayed(_parallel_helper)(e, 'predict', X, check_input=False)
        for e in forest.estimators_)

    # Reduce: the majority vote is the mode of the per-tree predictions
    modes, counts = mode(all_preds, axis=0)
    # The trees return their predictions as floats; cast so they can be
    # used as indices into classes_ below.
    modes = modes.astype(np.intp)

    if forest.n_outputs_ == 1:
        return forest.classes_.take(modes[0], axis=0)
    else:
        # Multi-output case: classes_ is a list of arrays, one per output.
        n_samples = all_preds[0].shape[0]
        preds = np.zeros((n_samples, forest.n_outputs_),
                         dtype=forest.classes_[0].dtype)
        for k in range(forest.n_outputs_):
            preds[:, k] = forest.classes_[k].take(modes[0, :, k], axis=0)
        return preds

On the dumb synthetic case I tried, predictions agreed with the predict method every time.
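
For reference, a check along those lines might look like this, with the function above in scope (the make_classification data is just a stand-in for whatever synthetic case you care about):

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=0)
clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

hard = predict_majvote(clf, X)  # majority vote over the individual trees
soft = clf.predict(X)           # argmax of the averaged probabilities
print(np.mean(hard == soft))    # fraction of agreement between the two schemes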

Danica
  • Great answer, Dougal! Thanks for taking the time to explain this carefully. Please consider also going over to stack overflow and answering [this question there](http://stackoverflow.com/questions/26899274/scikit-learn-randomforestclassifier-probabilistic-prediction-vs-majority-vote). – user1745038 Apr 28 '15 at 16:45
    There's also a paper, [here](http://statistics.berkeley.edu/sites/default/files/tech-reports/421.pdf), which addresses probabilistic prediction. – user1745038 Apr 28 '15 at 16:47