Scikit-learn Tree Design
Scikit-learn has a well-designed, extensible implementation of decision trees. The design pays dividends for techniques that rely on decision trees, such as random forests. Overall, it is a great example of thoughtfully designed machine learning code and is worth learning from.
Before jumping into the design, let's first review what decision trees, random forests, and extra trees are. If this is not new to you, jump ahead to the Scikit-learn Design section.
What are Decision Trees
Decision trees are a machine learning technique which recursively splits the feature space into two regions such that each region is easier to predict than the full data together. Each split is made along a single feature at a threshold that determines the two regions: datums with a feature value less than the threshold fall in one region, and datums with a value greater than the threshold fall in the other.
For the problem of trying to classify the orange vs blue points, a split is made on the feature along the X axis at a threshold of 0. Subsequent splits will operate on the left and right regions recursively.
The algorithm can be viewed as a tree, where at each node the algorithm searches for a feature to split on and a threshold for that feature. The feature and threshold selected are the pair that, if the split were made, would reduce a desired loss function the most.
The split produces two new nodes: datums with the feature below the threshold go to the left node and datums with the feature above the threshold go to the right node. For these two nodes, the algorithm repeats until a stopping condition is met. When the stopping condition is met, a leaf node is created which produces a prediction based on the datums in the leaf. For a regression problem, a common choice is to use mean squared error as the loss function and to predict the mean of the data in the leaf node.
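To make the recursion concrete, here is a minimal sketch of that fitting loop for a regression tree (plain Python using a nested dict as the node structure, not Scikit-learn's implementation):
import numpy as np

def build_tree(X, y, depth=0, max_depth=3, min_samples=2):
    # Every node can act as a leaf: it predicts the mean of its datums.
    node = {"prediction": y.mean()}
    if depth >= max_depth or len(y) < min_samples:
        return node  # stopping condition met
    best = None
    for feature in range(X.shape[1]):
        for threshold in np.unique(X[:, feature]):
            left = X[:, feature] <= threshold
            if left.all() or not left.any():
                continue  # a split must produce two non-empty regions
            # Loss if each region predicted its own mean (sum of squared errors).
            loss = ((y[left] - y[left].mean()) ** 2).sum() + \
                   ((y[~left] - y[~left].mean()) ** 2).sum()
            if best is None or loss < best[0]:
                best = (loss, feature, threshold, left)
    if best is not None:
        _, feature, threshold, left = best
        node.update(feature=feature, threshold=threshold,
                    left=build_tree(X[left], y[left], depth + 1, max_depth, min_samples),
                    right=build_tree(X[~left], y[~left], depth + 1, max_depth, min_samples))
    return node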
The tree shows the feature and threshold used for each split and how the algorithm recursively narrowed down the space to make more accurate predictions.
Decision trees produce easily interpretable models. To predict a new datum, one may simply follow the splits down the tree to determine the prediction. Further, domain experts can inspect the splits to ensure they match intuition.
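Scikit-learn exposes this interpretability directly; for example, export_text prints a fitted tree's splits as indented text, one line per node:
from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor, export_text

X, y = make_regression(n_samples=100, n_features=3, random_state=0)
tree = DecisionTreeRegressor(max_depth=2).fit(X, y)
# Prints the fitted splits (feature and threshold) for inspection.
print(export_text(tree, feature_names=["f0", "f1", "f2"]))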
Unfortunately, decision trees don’t produce especially good models and their predictions have high variance because changing the initial training data slightly can result in a drastically different model.
Fortunately, the fact that decision trees are high-variance weak learners means they work especially well with an ensembling technique called bagging.
Bagging
Bagging is a technique which trains many weak learners on different bootstrapped samples of the original data and averages the weak learners' predictions. Assuming the weak learners are uncorrelated, averaging their predictions reduces the variance of the overall prediction by a factor on the order of \(\frac{1}{T}\), where \(T\) is the number of weak learners. Even better, adding more weak learners won't lead to overfitting.
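A rough sketch of the procedure, using a Scikit-learn decision tree as the weak learner, looks like:
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def bagged_predict(X_train, y_train, X_test, n_estimators=100, seed=0):
    rng = np.random.default_rng(seed)
    predictions = []
    for _ in range(n_estimators):
        # Bootstrap sample: draw len(X_train) rows with replacement.
        idx = rng.integers(0, len(X_train), size=len(X_train))
        tree = DecisionTreeRegressor().fit(X_train[idx], y_train[idx])
        predictions.append(tree.predict(X_test))
    # Averaging the weak learners' predictions reduces the variance.
    return np.mean(predictions, axis=0)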
Random forests and extra trees are two techniques that use bagging with decision trees and differ only in how the trees search for splits.
Random Forests
A random forest searches for the best split at each node: it tries all possible (feature, threshold) split options and selects the one that reduces the loss the most.
Extra Trees
Extra trees sample from the possible (feature, threshold) split options and select the one that reduces the loss the most. This has the benefit of making tree construction faster and provides some regularization when selecting splits.
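Both ensembles are available in Scikit-learn with nearly identical interfaces:
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor

X, y = make_regression(n_samples=500, n_features=10, random_state=0)

# Random forest: each tree searches for the best split at every node.
forest = RandomForestRegressor(n_estimators=100, random_state=0).fit(X, y)

# Extra trees: each tree evaluates randomly drawn splits instead.
extra = ExtraTreesRegressor(n_estimators=100, random_state=0).fit(X, y)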
Scikit-learn Design
The main class in Scikit-learn's design is the DecisionTree. The DecisionTree extends BaseEstimator and is used by both RandomForest and ExtraTrees as the weak learner implementation. A DecisionTree can be configured for either regression or classification, and to search for either the best split (to support RandomForest) or the best random split (to support ExtraTrees).
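In the public API this configuration surfaces as DecisionTreeRegressor and DecisionTreeClassifier, whose criterion and splitter parameters select the task and the split search (the criterion strings vary across Scikit-learn versions; older releases use "mse" rather than "squared_error"):
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Regression tree that searches for the best split, as RandomForest does.
best_tree = DecisionTreeRegressor(criterion="squared_error", splitter="best")

# Classification tree that evaluates random splits, as ExtraTrees does.
random_tree = DecisionTreeClassifier(criterion="gini", splitter="random")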
DecisionTree maintains the tree of nodes, where each node encapsulates the feature and threshold used to split it and, for leaf nodes, the prediction. While constructing the tree, DecisionTree relies on a component called a Splitter to search for the feature and threshold to split on. The Splitter in turn uses a component called the Criterion to evaluate splits and find the one that best minimizes the impurity function associated with the given criterion. Both components are written in Cython because they can be called millions of times while fitting.
While at a high level the logic is the following, the design and implementation of the Splitter and Criterion allow for far greater performance:
for feature, threshold in possible_splits:
    loss = criterion.evaluate(feature, threshold)
    # keep the (feature, threshold) pair with the lowest loss
return best_split
Design of the Splitter
The Splitter decides how to search for the feature and threshold to split the node on. Scikit-learn has two different splitters, BestSplitter and RandomSplitter. BestSplitter searches for the best possible split across all features and thresholds, while RandomSplitter samples features and thresholds and selects the best split of the ones sampled. BestSplitter is used by RandomForest and RandomSplitter is used by ExtraTrees.
A Splitter has a public function, node_split, that produces a SplitRecord, which encapsulates a feature and threshold. While the two splitters have different search algorithms, they both maintain a list of features that are constant from ancestor nodes so they don't need to re-check those features at each split descending down the tree. For each candidate feature and threshold, the Splitter calls into the Criterion, first to update the Criterion for the current split and then to produce a low-cost proxy for the impurity of the split. Once the best split is selected, the Splitter resets the Criterion to the best split and computes the actual impurity for the split as well as the impurity of the left and right children.
The pseudo-code is:
for feature in possible_features:
    for threshold, pos in possible_thresholds(feature):
        criterion.update(pos)
        proxy = criterion.proxy_impurity_improvement()
        # track the split with the best proxy improvement
# re-evaluate the winning split exactly
criterion.update(best_split.pos)
criterion.impurity_improvement()
return best_split
A lot of care is taken in the criterion to minimize computation so tree building is fast.
Design of the Criterion
The Criterion computes the impurity of a given split and produces the values the node will return for predictions. There are regression- and classification-specific Criterion implementations that both fit within the same interface. It is critical that the Criterion is fast because it is called within the inner loop.
The Criterion is designed so that, for a given feature, state can be maintained such that changing the split threshold from \(t_1\) to \(t_2\) requires computation on the order of the number of elements that moved from the right to the left child, instead of the total number of elements in the node. This is crucial when using BestSplitter because BestSplitter will try all possible threshold values; recomputing the loss without state would be intolerably slow. As an example, the default regression Criterion, mse (mean squared error), maintains the count, sum, and sum of squares of the elements in the left child along with the totals for the node, so that changing the threshold only requires adding to or subtracting from the left child's state; the right child's state is then obtained by subtracting the left child's state from the totals.
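A plain Python sketch of that bookkeeping (not the Cython implementation, and using hypothetical names) shows why moving the threshold is cheap:
class RunningMSE:
    def __init__(self, y):
        # Totals for the node are computed once.
        self.n_total, self.sum_total = len(y), float(sum(y))
        self.sq_total = float(sum(v * v for v in y))
        # Running state for the left child starts empty.
        self.n_left, self.sum_left, self.sq_left = 0, 0.0, 0.0

    def move_left(self, values):
        # Only the datums that crossed the new threshold are touched.
        for v in values:
            self.n_left += 1
            self.sum_left += v
            self.sq_left += v * v

    @staticmethod
    def _impurity(n, s, sq):
        return sq / n - (s / n) ** 2  # E[y^2] - E[y]^2, the MSE of predicting the mean

    def children_impurity(self):
        # The right child's state is the node totals minus the left child's state.
        n_r = self.n_total - self.n_left
        return (self._impurity(self.n_left, self.sum_left, self.sq_left),
                self._impurity(n_r, self.sum_total - self.sum_left,
                               self.sq_total - self.sq_left))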
The Criterion provides two functions to compute the impurity improvement for a split: proxy_impurity_improvement and impurity_improvement. proxy_impurity_improvement is used while searching for the best split, and implementations are able to drop any constant or proportional terms to reduce the cost of evaluating each split. When the best split is selected, impurity_improvement is called to get the exact value.
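For the mean squared error case, dropping the terms that are identical for every candidate split leaves a very cheap quantity to maximize. A sketch of the two functions (hypothetical standalone helpers that only approximate the Cython code) is:
def proxy_impurity_improvement(n_left, sum_left, n_right, sum_right):
    # For MSE, minimizing the children's weighted impurity is equivalent to
    # maximizing this quantity; the remaining terms are constant across splits.
    return sum_left ** 2 / n_left + sum_right ** 2 / n_right

def impurity_improvement(n_node, n_total, impurity_node,
                         n_left, impurity_left, n_right, impurity_right):
    # The exact weighted decrease in impurity, computed once for the chosen split.
    return (n_node / n_total) * (
        impurity_node
        - (n_left / n_node) * impurity_left
        - (n_right / n_node) * impurity_right)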
Finally, for leaf nodes, Criterion has a function, node_value, which returns the value the node predicts.
Scikit-learn provides mean squared error, mean absolute error, and Friedman mean squared error as criteria for regression, and gini and entropy for classification.
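Users select a criterion by name when constructing the estimator (again, the exact strings depend on the Scikit-learn version; "mse" and "mae" were later renamed "squared_error" and "absolute_error"):
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Classification: gini impurity (default) or entropy.
clf = DecisionTreeClassifier(criterion="entropy")

# Regression: squared error (default), absolute error, or Friedman's variant.
reg = DecisionTreeRegressor(criterion="friedman_mse")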
RandomForest and ExtraTrees
The extensibility of DecisionTree shows its true benefit when considering the ensemble estimators that build on top of it: RandomForest and ExtraTrees.
RandomForest configures the splitter to be BestSplitter and allows the user to select the desired criterion. Similarly, ExtraTrees configures the splitter to be RandomSplitter, and users may use the same criteria they would select for RandomForest.
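That composability means a rough equivalent of either ensemble can be assembled by hand from the public pieces, which makes clear how little the ensemble classes add beyond configuring the tree. This is only a sketch: it ignores details such as per-split feature subsampling and bootstrap defaults, and older Scikit-learn versions call the estimator parameter base_estimator.
from sklearn.ensemble import BaggingRegressor
from sklearn.tree import DecisionTreeRegressor

# Roughly a random forest: bagged trees searching for the best split.
forest_like = BaggingRegressor(
    estimator=DecisionTreeRegressor(splitter="best"), n_estimators=100)

# Roughly extra trees: bagged trees evaluating random splits instead.
extra_like = BaggingRegressor(
    estimator=DecisionTreeRegressor(splitter="random"), n_estimators=100)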
Extending the Scikit-learn Tree Design with a New Criterion
While you can pass in a custom Criterion, extending the Criterion interface is difficult because Scikit-learn doesn't ship the Cython header files (.pxd files) for it. To make them available, one must either implement the Criterion within the Scikit-learn package itself or manually copy the .pxd files into the proper Scikit-learn install directory so they can be imported.
Conclusion
Scikit-learn's tree module is an example of designing a general implementation that relies on composition to allow the tool to be extended for many applications.