Scaling LightGBM with Dask

ODSC - Open Data Science
Mar 11, 2021

LightGBM is an open-source framework for solving supervised learning problems with gradient-boosted decision trees (GBDTs). It ships with built-in support for distributed training, which just means “using multiple machines at the same time to train a model”. Distributed training can allow you to train on larger datasets, or can provide speedups that make it possible to train larger models.

In this article, you’ll learn how to use Python and Dask to take advantage of distributed LightGBM training.

First, some history!

If you just want to start running code and scaling LightGBM, this section can be skipped.

LightGBM was first released as an open-source project in August 2016. It was formally introduced to the machine learning community at the 2017 Neural Information Processing Systems conference (NIPS 2017). That original paper described several state-of-the-art features that made LightGBM faster than other GBDT approaches without sacrificing accuracy (the short sketch after this list shows how they map to everyday parameters):

  • Bucketing continuous features into histograms to reduce the number of split points considered
  • Automatic optimizations for sparse data (“Exclusive Feature Bundling”)
  • Downsampling that considers samples’ gradients (“Gradient-based One-Side Sampling”)
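These techniques don't require any special APIs; they surface as ordinary LightGBM parameters. The values below are purely illustrative, not recommendations.

import lightgbm as lgb

# Histogram bucketing, Exclusive Feature Bundling, and GOSS are all
# controlled through regular parameters (values shown for illustration only)
reg = lgb.LGBMRegressor(
    max_bin=255,           # number of histogram bins per continuous feature
    enable_bundle=True,    # Exclusive Feature Bundling for sparse features
    boosting_type="goss",  # Gradient-based One-Side Sampling
)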

That initial release of LightGBM also included a built-in framework for distributed training, based on research from a 2016 NIPS paper called “A Communication-Efficient Parallel Algorithm for Decision Tree.” LightGBM’s core library was implemented in C++, and from 2016 until now it has picked up official wrapper packages in R, Python, and Java, as well as many unofficial packages in other languages such as Julia, Go, and C#.


As of late 2018, LightGBM's distributed training was officially supported only through its command-line interface (CLI) and, in a limited way, through the LightGBM Python package. These options required data scientists who wanted distributed LightGBM training to either prepare training data as files on the worker machines beforehand or rely on LightGBM to move data over the network out to worker machines. To improve this experience, in October 2018 a small group of developers unaffiliated with the LightGBM project created dask-lightgbm, a package that allowed training on Dask DataFrames and Dask Arrays. It was modeled on dask-xgboost, which had been created in February 2017 to offer a Dask interface to XGBoost distributed training.

In September 2019, dask-xgboost was merged into the main XGBoost Python library. In November 2020, dask-lightgbm was merged into the main LightGBM Python library.

Ok, let’s train a model

LightGBM’s Dask interface ships as part of the `lightgbm` package, which can be installed by following the instructions from https://github.com/microsoft/LightGBM/tree/master/python-package#install-dask-package.

To use Dask features, you’ll need to start a Dask cluster and create a client for it. In the language of Dask, a “cluster” is a collection of processes that know how to do work (“workers”) and a process that tells them what to do (“scheduler”).

(Diagram of a Dask cluster: a scheduler coordinating multiple workers. Image credit: James Bourbeau)

LightGBM does not care what type of Dask cluster you use. For simplicity, this example uses a distributed.LocalCluster.

from dask.distributed import Client, LocalCluster, wait

n_workers = 3
cluster = LocalCluster(n_workers=n_workers)
client = Client(cluster)
client.wait_for_workers(n_workers)
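The same training code works with any other Dask deployment. If you already have a Dask scheduler running somewhere else, you can connect a client to it by address instead; the address below is just a placeholder, not a real endpoint.

from dask.distributed import Client

# connect to an existing Dask scheduler (placeholder address)
client = Client("tcp://your-scheduler-address:8786")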

While you work through this example, you can monitor the resource utilization in the Dask cluster using the Dask diagnostic dashboard.

print(f"View the dashboard: {cluster.dashboard_link}")

The code below sets up some random training data as Dask Arrays. These look like numpy arrays, but are split into multiple smaller pieces called “chunks” that can be spread across the workers.

import dask.array as da

num_rows = 1_000_000
num_features = 100
num_partitions = 10
rows_per_chunk = num_rows // num_partitions

# feature matrix: 1 million rows x 100 columns, split into 10 chunks
data = da.random.random(
    size=(num_rows, num_features),
    chunks=(rows_per_chunk, num_features)
)

# regression target: one value per row, chunked the same way
labels = da.random.random(
    size=(num_rows,),
    chunks=(rows_per_chunk,)
)
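These arrays are lazy, so nothing has actually been generated yet. It can be convenient (though not required) to materialize them in worker memory before training, so the random data isn't regenerated by later computations; a minimal sketch using persist() and the wait() function imported earlier:

# trigger computation and keep the chunks in distributed memory
data = data.persist()
labels = labels.persist()
_ = wait([data, labels])

# each array is split into 10 chunks along the rows
print(data.chunks)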

LightGBM’s Dask module provides estimators whose interfaces mirror the standard `lightgbm` scikit-learn API. You can pass any of the standard LightGBM parameters to these estimators to control the learning process.

import lightgbm as lgb

# "data_parallel" distributes the training data across workers
# and merges their feature histograms during training
dask_reg = lgb.DaskLGBMRegressor(
    max_depth=5,
    learning_rate=0.1,
    tree_learner="data_parallel",
    n_estimators=100,
    min_child_samples=1,
)
dask_reg.fit(data, labels)

When you call .fit(), LightGBM will start up one training task on each worker. Each worker will train only on the chunks of the data that it has locally, so LightGBM will never waste time and memory shuffling your training data between workers.

For more details on how distributed LightGBM training works, see “Optimization in Parallel Learning” in the LightGBM documentation.

The model object produced by training is an instance of lgb.DaskLGBMRegressor. If you don’t want to have Dask as a dependency when you deploy this model, you can get a regular lightgbm.sklearn.LGBMRegressor from it with .to_local().

local_reg = dask_reg.to_local()
print(type(local_reg))
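The local model behaves like any other lightgbm scikit-learn estimator, so, for example, it can score a plain numpy array without any Dask machinery (the random input here is only for illustration):

import numpy as np

# predict with the non-Dask model on in-memory numpy data
sample = np.random.random((5, num_features))
print(local_reg.predict(sample))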

For an even lower-level model object, you can also extract a LightGBM Booster from the fitted model object.

booster = dask_reg.booster_
print(type(booster))
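If you prefer LightGBM's native text format to pickling, the Booster can also be saved and reloaded directly:

# save the Booster in LightGBM's own text format...
booster.save_model("model.txt")

# ...and load it back later without Dask or scikit-learn wrappers
loaded_booster = lgb.Booster(model_file="model.txt")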

These objects can be saved in a binary format with cloudpickle, joblib, or pickle.

import cloudpickle

with open("model.pkl", "wb") as f:
    cloudpickle.dump(dask_reg, f)
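Loading the model back later is symmetric:

# read the pickled model back into memory
with open("model.pkl", "rb") as f:
    loaded_reg = cloudpickle.load(f)
print(type(loaded_reg))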

Evaluate Your Model

Finally, you can also use tools from the Dask ecosystem for model evaluation. The .predict() methods on the LightGBM Dask estimators produce predictions in Dask Array or Dask DataFrame format. The dask-ml project offers a scikit-learn-like API designed for data stored in Dask Arrays or Dask DataFrames.

from dask_ml.metrics import mean_absolute_error

preds = dask_reg.predict(X=data)

# dask-ml computes the error across all chunks and returns a single float
print(mean_absolute_error(labels, preds))
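In a real workflow you would evaluate on held-out data rather than on the training set. dask-ml's model selection utilities work directly on Dask Arrays; here is a hedged sketch using its train_test_split:

from dask_ml.model_selection import train_test_split

# hold out 20% of the rows for evaluation
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2)

dask_reg.fit(X_train, y_train)
test_preds = dask_reg.predict(X_test)
print(mean_absolute_error(y_test, test_preds))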

Conclusion

In this short article, you learned about the new LightGBM Dask module, the history of the open source projects that led to it, and how to use it in a brief hands-on tutorial.

To learn more about LightGBM distributed training with Dask, see my upcoming ODSC talk, “Scaling Machine Learning with Dask” or our tutorial on “LightGBM Training with Dask” in Saturn Cloud’s documentation. And of course, stop by https://github.com/microsoft/LightGBM if you’d like to request a feature, report a bug, or contribute!
