Extracting trees from GBM models as data frames

Krzysztof Joachimiak
4 min read · Jun 11, 2023


Image from Freepik (by macrovector)

Popular Python libraries for GBM models mostly share their APIs, but they're not identical. In this article, I focus on the methods used to extract individual decision trees from these models. First, I present the native, package-specific ways to do that; at the end, I introduce a new, common interface which lets us export tree structures from arbitrary tree-based GBM models.

LightGBM

I'm starting with LightGBM, which provides a convenient trees_to_dataframe method. Let's fit a model to see what the output looks like.

from lightgbm import LGBMRegressor
from sklearn.datasets import make_regression

# Fitting a model
X, y = make_regression()
lgb_reg = LGBMRegressor().fit(X, y)

# Extracting trees
lgb_reg.booster_.trees_to_dataframe()
Trees extracted from a LightGBM model

When I started working on a common interface for tree extraction, I decided to follow this data frame schema more or less.

XGBoost

Fortunately, XGBoost offers the same method in its API. I suspect the idea for this functionality was borrowed by LightGBM, since XGBoost is the older project. As we can see, the resulting data frame has 11 columns, fewer than its LightGBM equivalent.

from xgboost import XGBRegressor

# Fitting a model
X, y = make_regression()
xgb_reg = XGBRegressor().fit(X, y)

# Extracting trees
xgb_reg.get_booster().trees_to_dataframe()
Trees extracted from an XGBoost model

CatBoost

CatBoost doesn't have its own version of trees_to_dataframe, but that shouldn't surprise us too much. CatBoost typically uses so-called oblivious trees, which are symmetrical and apply one shared split threshold per tree level. Returning the parameters of every single node would therefore be highly redundant. The best way to get at the CatBoost trees is to export the model with the save_model method. Alternatively, we can call one of the _get_tree_* methods to fetch specific information directly from the model.
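To see why node-by-node output would be redundant, here is a toy sketch (plain Python, not CatBoost's API) of how an oblivious tree of depth d can be stored: just one (feature, threshold) pair per level plus 2**d leaf values, with a sample's leaf index assembled bit by bit from the level-wise comparisons.

```python
def oblivious_predict(x, level_splits, leaf_values):
    """Predict with a toy oblivious tree.

    level_splits: list of (feature_index, threshold), one pair per level
    leaf_values: list of 2**depth leaf predictions
    """
    leaf_index = 0
    for feature, threshold in level_splits:
        # Shift in one bit per level: 1 if we go right, 0 if we go left
        leaf_index = (leaf_index << 1) | (x[feature] > threshold)
    return leaf_values[leaf_index]

# Depth-2 example: two (feature, threshold) pairs describe the whole
# tree, which has four leaves
splits = [(0, 0.5), (1, -1.0)]
leaves = [10.0, 20.0, 30.0, 40.0]
print(oblivious_predict([0.7, 0.0], splits, leaves))  # right, right -> 40.0
```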

from catboost import CatBoostRegressor

# Fitting a model
X, y = make_regression()
cb_reg = CatBoostRegressor().fit(X, y, verbose=False)

# Saving model
cb_reg.save_model('cb.dump', 'json')

# Alternative - "private" methods
obj = cb_reg._object

# CatBoost general tree stats
obj._get_tree_count()
obj._get_tree_leaf_counts()

# Node values
obj._get_tree_leaf_values()
obj._get_tree_node_to_leaf()
obj._get_tree_splits()
obj._get_tree_step_nodes(0)

# Drawing an example of a CatBoost symmetrical tree
# We intentionally limit the maximum tree depth to have a clearer graph
cb_reg = CatBoostRegressor(depth=3).fit(X, y, verbose=False)
cb_reg.plot_tree(
    tree_idx=0
).render(filename='cb_tree', format='png')
A CatBoost oblivious tree

scikit-learn

In scikit-learn's GBMs, we can extract the full list of estimators and traverse them, but there's no built-in way to export them as a single data frame.

from sklearn.ensemble import GradientBoostingRegressor

# Fitting a model
X, y = make_regression()
gb_reg = GradientBoostingRegressor().fit(X, y)

# Extracting trees
gb_reg.estimators_
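Since each estimator exposes its raw node arrays through the tree_ attribute, a data frame resembling the ones above can be assembled by hand. Below is a minimal sketch; the column names are my own choice, loosely mirroring the LightGBM schema, and are not part of scikit-learn's API.

```python
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor

X, y = make_regression(random_state=0)
gb_reg = GradientBoostingRegressor(n_estimators=5).fit(X, y)

rows = []
# For a regressor, estimators_ has shape (n_estimators, 1)
for tree_index, est in enumerate(gb_reg.estimators_[:, 0]):
    t = est.tree_  # raw node arrays of a single fitted tree
    for node in range(t.node_count):
        rows.append({
            'tree_index': tree_index,
            'node_index': node,
            'left_child': t.children_left[node],    # -1 for leaves
            'right_child': t.children_right[node],  # -1 for leaves
            'split_feature': t.feature[node],       # -2 for leaves
            'threshold': t.threshold[node],
            'value': t.value[node].item(),
            'count': int(t.n_node_samples[node]),
        })

trees_df = pd.DataFrame(rows)
print(trees_df.head())
```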

The common interface

The scikit-gbm package offers a unified interface for extracting trees as data frames from the following GBM models:

  • XGBoost
  • LightGBM
  • CatBoost
  • GradientBoosting* (scikit-learn)

Column names are kept consistent across all the model-specific implementations, but not all characteristics are available for every implementation at the moment. They're clearly listed in the scikit-gbm documentation.

trees_to_dataframe: list of available columns

How to use it? First, install the newest version of scikit-gbm using the command below.

pip install scikit-gbm -U

Then, import the trees_to_dataframe function and use it as shown below.

from skgbm.tools import trees_to_dataframe
from sklearn.datasets import make_regression
from sklearn.ensemble import GradientBoostingRegressor
from xgboost import XGBRegressor
from lightgbm import LGBMRegressor
from catboost import CatBoostRegressor

# Creating data
X, y = make_regression()

# Regressors
xgb_reg = XGBRegressor().fit(X, y)
lgb_reg = LGBMRegressor().fit(X, y)
cb_reg = CatBoostRegressor().fit(X, y, verbose=False)
gb_reg = GradientBoostingRegressor().fit(X, y)

# Getting trees as dataframe
xgb_df = trees_to_dataframe(xgb_reg)
lgb_df = trees_to_dataframe(lgb_reg)
cb_df = trees_to_dataframe(cb_reg)
gb_df = trees_to_dataframe(gb_reg)

# Example
gb_df
A data frame returned for a GradientBoostingRegressor model

See also

Generating features with gradient boosted decision trees
The forgotten initial estimator in GBMs
