Roll your own XGBoost model

Extract your trained XGBoost model into its own super fast, super light module, and understand a little more about our favorite machine learning algo along the way

Chris White
The Startup
May 29, 2020

(Disclaimer: the Treelite library will convert XGBoost and other tree models into direct code with some clever optimizations. Our motivation here is to produce a custom implementation and to understand more about how XGBoost works)

We all know and love XGBoost. Winner of competitions and robust in the face of missing data points, it’s as close to the Swiss army knife of machine learning algos as they come. Yes, there is competition from LightGBM, CatBoost and others, but we’ve been through so much together since the machine learning neolithic ages of 2014.

The day comes, though, when you want to understand more about our dark mistress. And that leads you to play around with the C++ function XGBoosterDumpModel, or model.dump_model if you’re coming at this from Python. The docs look so promising…

From https://xgboost.readthedocs.io/en/latest/python/python_api.html

It works! A lovely text file, fertile with promise, awaits you:

This looks simple enough. A set of trees, an easy-to-understand if-then structure. If feature 5 is less than 6.94 and feature 12 is less than 15 and feature 5 is less than 6.54 then the value of the tree is 2.11. And so on.
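As a rough illustration, here is what such a tree can look like in the text dump layout, next to a hand translation into nested if-then code. Only the path f5<6.94, f12<15, f5<6.54, leaf=2.11 comes from the description above; the other leaf values, the node ids and the missing-value directions are invented for the example, and `tree_0` is a hypothetical name.

```python
import math

# Hypothetical single tree in XGBoost's text dump layout. Each split line
# names the feature and threshold plus the child node to visit for
# "yes" (condition true), "no" (condition false) and "missing".
EXAMPLE_DUMP = """\
booster[0]:
0:[f5<6.94] yes=1,no=2,missing=1
    1:[f12<15] yes=3,no=4,missing=3
        3:[f5<6.54] yes=5,no=6,missing=5
            5:leaf=2.11
            6:leaf=1.5
        4:leaf=0.9
    2:leaf=-0.7
"""


def tree_0(features):
    """Hand translation of EXAMPLE_DUMP; NaN marks a missing value."""

    def goes_yes(value, threshold):
        # In this made-up tree every "missing" branch equals the "yes" branch,
        # so a missing value is treated like a value below the threshold.
        return math.isnan(value) or value < threshold

    if goes_yes(features[5], 6.94):
        if goes_yes(features[12], 15):
            if goes_yes(features[5], 6.54):
                return 2.11
            return 1.5
        return 0.9
    return -0.7
```

A sample with feature 5 equal to 6.0 and feature 12 equal to 10.0 follows the "yes" branch three times and lands on the 2.11 leaf.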

You then think to build a parser which takes these trees and implements them directly in your language of choice, in our case F#:

And a little fiddling later, you are able to take a trained XGBoost model, dump it to text and then re-implement in a fast, efficient and lightweight fashion, free of all the cruft and overhead. Pass in your data and nanoseconds later you should have predictions.
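The parser itself does not need to be long. The following is a minimal Python sketch rather than the F# implementation mentioned above. It assumes the text layout written by dump_model with default feature names (f0, f1, ...) and a "booster[n]:" line before each tree. `parse_dump` and `predict_tree` are hypothetical names, and the sample dump at the bottom is made up.

```python
import math
import re

# Matches split lines such as "0:[f5<6.94] yes=1,no=2,missing=1"
SPLIT_RE = re.compile(r"(\d+):\[f(\d+)<([^\]]+)\] yes=(\d+),no=(\d+),missing=(\d+)")
# Matches leaf lines such as "3:leaf=2.11"
LEAF_RE = re.compile(r"(\d+):leaf=([^,\s]+)")


def parse_dump(text):
    """Parse a text dump into a list of trees, each a dict of node id -> node."""
    trees = []
    for line in text.splitlines():
        line = line.strip()
        if line.startswith("booster["):
            trees.append({})  # a new tree starts here
            continue
        split = SPLIT_RE.match(line)
        if split:
            node_id, feature, threshold, yes, no, missing = split.groups()
            trees[-1][int(node_id)] = (
                "split", int(feature), float(threshold),
                int(yes), int(no), int(missing),
            )
            continue
        leaf = LEAF_RE.match(line)
        if leaf:
            trees[-1][int(leaf.group(1))] = ("leaf", float(leaf.group(2)))
    return trees


def predict_tree(tree, features):
    """Walk one parsed tree from node 0 down to a leaf and return its value."""
    node = tree[0]
    while node[0] == "split":
        _, feature, threshold, yes, no, missing = node
        value = features[feature]
        if math.isnan(value):
            node = tree[missing]  # missing values follow their own branch
        elif value < threshold:
            node = tree[yes]
        else:
            node = tree[no]
    return node[1]


# Made-up two-tree dump, only to exercise the functions above.
SAMPLE = """booster[0]:
0:[f5<6.94] yes=1,no=2,missing=1
    1:leaf=2.11
    2:leaf=-0.7
booster[1]:
0:leaf=0.3
"""
sample_trees = parse_dump(SAMPLE)
sample_features = [0.0, 0.0, 0.0, 0.0, 0.0, 6.0]
leaf_values = [predict_tree(t, sample_features) for t in sample_trees]
```

This walks a lookup table at prediction time. Generating source code from the parsed trees, as the article does, avoids even that lookup, but the parsing step is the same.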

Except… where are the predictions? Do you just sum up all the tree terminal values? That doesn’t work. Well, it does when the objective function for the XGBoost model is a regression objective, such as “reg:squarederror”. But what about the classifier default, “binary:logistic”? And multi-class objectives like “multi:softmax” and “multi:softprob”? The docs are silent on this, so we spent many hours on StackOverflow and Kaggle. Here are the answers you seek.

The Tree Structure

First, how many individual trees do we have and why?

The number of trees in your model will equal the number of estimators (“n_estimators” in the tree parameters) multiplied by the number of classes (“num_class”). For a regression model or the default binary logistic model, the number of classes equals 1 and thus the number of trees will simply be equal to the number of estimators, understandably enough.

For multi-class classification, the number of trees will obviously be the number of estimators multiplied by the number of classes you are trying to predict.

Each tree contributes to the raw prediction of a given class. If you have 3 classes, then tree 0 will contribute to the raw prediction of class 0, tree 1 to class 1, tree 2 to class 2, and then looping back such that tree 3 contributes to the raw predictions of class 0. And so on.

Of course if you are using regression or a binary model, then the number of classes is 1 and thus each tree is contributing to the output of that class.

So each tree contributes to the raw prediction of the class whose number is equal to tree number modulo the number of classes, or tree_number % num_classes.
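The modulo rule can be sketched in a few lines of Python. `raw_predictions` is a hypothetical helper that takes the leaf value each tree returned for one sample, in tree order; the input numbers are invented.

```python
def raw_predictions(tree_values, num_class=1):
    """Sum per-tree leaf values into one raw score per class.

    tree_values[i] is the leaf value that tree i returned for one sample.
    Any global offset (such as XGBoost's base_score) is NOT included here.
    """
    raw = [0.0] * num_class
    for tree_number, value in enumerate(tree_values):
        raw[tree_number % num_class] += value
    return raw


# Invented leaf values from 6 trees: 2 estimators x 3 classes.
three_class = raw_predictions([0.5, -0.25, 0.125, 0.25, 0.5, -0.125], num_class=3)
print(three_class)  # [0.75, 0.25, 0.0]

# With num_class=1 (regression or binary), every tree feeds the same score.
single = raw_predictions([1.0, 2.0, 0.5])
print(single)  # [3.5]
```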

Raw Predictions -> Probabilities

Your raw predictions will thus obviously be an array of length num_class. In the case of regression and binary logistic, as we pointed out above, that length will be 1, which you can treat as simply being a scalar. For other objectives, it will equal the number of classes you are estimating.

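For regression models, the raw prediction (the length-1 array) is all you need! That single value, the sum of the terminal values reached in each tree, is the regression prediction for a given sample of features. One caveat: XGBoost also applies a global offset set by the “base_score” parameter, which does not appear in the tree dump, so depending on your settings and library version you may need to add it to the sum.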

For other models the story is more complicated.

Binary Logistic models

As mentioned above, where the model objective is binary logistic (i.e. we are answering a yes/no question as to whether a given sample of features predicts an outcome or not) then the raw prediction will be an array of length 1.

This number is the logit score, i.e. the log-odds, of the prediction being yes.

To convert this score into a probability we need to apply the logistic function, which, while it can look complex, boils down to 1. / (1. + Math.Exp(-x)) in F# or 1 / (1 + math.exp(-x)) in Python. This value is the probability that the features are predicting a true or yes outcome. Obviously the probability that the features are predicting a false or no outcome is simply 1 minus that number.

To convert this probability into a 0/1 binary result, simply select the outcome with the larger probability. If the probability of a true outcome is 50% or greater, then we have a true prediction, otherwise false.
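Put together, the binary case might look like the sketch below. `logistic` and `predict_binary` are hypothetical names, and the 0.5 threshold is simply the 50% cut-off described above.

```python
import math


def logistic(raw_score):
    """Map a raw logit score to a probability between 0 and 1."""
    return 1.0 / (1.0 + math.exp(-raw_score))


def predict_binary(raw_score, threshold=0.5):
    """Return (probability of 'yes', 0/1 class label) for one raw score."""
    probability = logistic(raw_score)
    return probability, int(probability >= threshold)


# A raw score of 0 sits exactly on the fence: 50%, which counts as 'yes'.
fence = predict_binary(0.0)
# A negative raw score means 'no' is the more likely outcome.
likely_no = predict_binary(-2.0)
```

Note that math.exp raises OverflowError for very large arguments, so a production version would want to guard against extremely negative raw scores.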

Easy!

Multi-Class models

For multi-class models, we have a raw prediction array whose length is the number of classes, holding the summed terminal values of the trees which contribute to each class prediction. First we transform that array by applying the exponential function to each value, i.e. Math.Exp(x) in F# or math.exp(x) in Python. Note that we do not take the reciprocal here, nor do we take the negative of each number.

We then divide each transformed prediction by the sum of all the transformed predictions to get the probability that the features are predicting membership of each class. I.e. if our raw output array were [ -2.4, -0.4, -0.8 ] then our transformed prediction array would be approximately [ 0.091, 0.670, 0.449 ] and our probability for class 0 would be 0.091 / sum(0.091, 0.670, 0.449) or about 7.5%. And so on.

As for binary models, we can convert this into a single class prediction simply by choosing the class with the highest probability. In our example above, this would be class 1 at 55%.
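The same calculation as a Python sketch, using the raw output array from the example. `softmax` is a hypothetical helper name. Subtracting the largest score before exponentiating is an addition that is not part of the steps described above: it leaves the probabilities unchanged but keeps math.exp from overflowing on large scores.

```python
import math


def softmax(raw_scores):
    """Convert raw per-class scores into probabilities that sum to 1."""
    # Shifting by the maximum does not change the ratios, only the scale.
    highest = max(raw_scores)
    exps = [math.exp(score - highest) for score in raw_scores]
    total = sum(exps)
    return [value / total for value in exps]


probabilities = softmax([-2.4, -0.4, -0.8])
# Index of the largest probability is the predicted class.
predicted_class = max(range(len(probabilities)), key=probabilities.__getitem__)
print([round(p, 3) for p in probabilities], predicted_class)
```

For this input the probabilities come out at roughly 7.5%, 55.4% and 37.1%, so the predicted class is 1.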

That’s all there is to it!

Summary

XGBoost is a wonderful workhorse that can produce robust predictions with the dirtiest of data and very little required in terms of preparation. The native C++ implementation, while fast for training, is sometimes too slow for a production environment. Unwrapping your trained model into a series of if-then statements offers a path to significant speed upgrades and, hopefully, a little more insight into how this elegant model works along the way.
