How to convert an xgboost model to ONNX: A step-by-step guide

Ved Suhas Paranjape
Published in Turo Engineering
Sep 11, 2024
Onyx among gradient boosted trees 😆

Introduction

At Turo, all our production microservices are JVM-based (Kotlin), while the language of choice for our data science team is Python, which made deploying machine learning models to production challenging. Data scientists prefer Python because of its rich machine learning libraries (sklearn, xgboost, PyTorch, etc.), faster prototyping and strong community support. In the past, we used the JPMML library to integrate ML models into our Java microservices, but the library had several shortcomings, such as limited community support and being error-prone. We evaluated several options for training models in Python and deploying them in Java microservices, including XGBoost4J and H2O.ai, and ONNX proved to be the best suited for the job because of the following strengths:

  • Support for running inference in batch mode (vectorized inference)
  • Wide range of machine learning model conversions supported
  • Inference speed (a 5–10x speed-up over the pickled model, as observed in a real-world experiment)
  • Great community support

One shortcoming of the ONNX framework is that converting ML models to ONNX can be daunting at first. However, the effort is worth it for the interoperability and inference speed it provides.

This document describes the general sequence of steps, with examples, that you can follow to train an sklearn pipeline containing an XGBoost classifier and convert it to ONNX.

Step 1: Get your Data

Let’s use a small toy dataset for our purpose, in which F1, F2, F3 and F4 are the input features and Y is the target variable.
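To run the later snippets end to end you need a DataFrame with these columns. The sketch below builds a hypothetical random stand-in (the values and the rule used to derive Y are made up for illustration) and splits it into the X_train and y_train that Step 3 assumes; the 75/25 split and the random seed are arbitrary choices.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(seed=0)
n_rows = 200

# Hypothetical stand-in for the toy dataset: four float features
df = pd.DataFrame(
    {
        "F1": rng.uniform(-5.0, 25.0, n_rows),  # wide range, so clipping to [0, 17] matters
        "F2": rng.normal(10.0, 3.0, n_rows),
        "F3": rng.normal(8.0, 3.0, n_rows),
        "F4": rng.integers(0, 5, n_rows).astype(np.float64),
    }
)
# Hypothetical binary target, only so a classifier has something to learn
df["Y"] = ((df["F2"] - df["F3"]) + 0.1 * df["F1"].clip(0.0, 17.0) > 2.0).astype(np.int64)

X = df[["F1", "F2", "F3", "F4"]]
y = df["Y"]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)
print(X_train.shape, X_test.shape)
```

Keeping every feature as float64 here avoids mixing integer and float inputs later, when each column becomes its own ONNX input in Step 6.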

Step 2: Feature Engineering

Suppose we need to preprocess (transform) our features before training the model by applying the following transformations:

  1. Limit (clip) the values of feature F1 between 0.0 and 17.0
  2. Calculate the difference between F2 and F3 as F2-F3

When creating a preprocessing pipeline for a model, data scientists typically use the sklearn.preprocessing package. If we were only exporting the model to pickle or joblib format, we would have used a FunctionTransformer from sklearn.preprocessing for the feature transforms above; however, the sklearn-onnx library that we use later does not support it, since an arbitrary Python function cannot be translated into ONNX operators automatically. For that reason, we write custom sklearn transformers, as shown below.

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin


class Clipper(BaseEstimator, TransformerMixin):  # type: ignore
    """Clip a single column to [lower_limit, upper_limit]."""

    def __init__(
        self, lower_limit: int | float, upper_limit: int | float, column: str
    ) -> None:
        """Initialize transformer with the expected column and bounds."""
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit
        self.column = column

    def clip(self, data: int | float | np.ndarray) -> np.ndarray:
        """Clip the given input within bounds."""
        return np.clip(data, self.lower_limit, self.upper_limit)

    def fit(self, X: pd.DataFrame | np.ndarray, y: None = None):  # type: ignore
        """Fit the transformer (stateless, nothing to learn)."""
        return self

    def transform(self, X: pd.DataFrame | np.ndarray, y: None = None) -> np.ndarray:
        """Transform the given data into an (n_samples, 1) array."""
        if isinstance(X, pd.DataFrame):
            # Vectorized clip of the selected column
            return self.clip(X[self.column].to_numpy()).reshape((-1, 1))
        # ndarray input: expected to hold only the selected column
        return self.clip(np.asarray(X)).reshape((-1, 1))

    def get_feature_names_out(self, input_features=None) -> np.ndarray:
        """Return feature names. Required for onnx conversion."""
        return np.array([self.column], dtype=object)


class DifferenceCalculator(BaseEstimator, TransformerMixin):  # type: ignore
    """Compute column_1 - column_2 for the given columns."""

    def __init__(self, column_1: str, column_2: str) -> None:
        """Initialize transformer with expected columns."""
        self.column_1 = column_1
        self.column_2 = column_2

    def calculate_difference(self, x: np.ndarray, y: np.ndarray) -> np.ndarray:
        """Difference calculator function."""
        return x - y

    def fit(self, X, y=None):  # type: ignore
        """Fit the transformer (stateless, nothing to learn)."""
        return self

    def transform(self, X: pd.DataFrame | np.ndarray, y: None = None) -> np.ndarray:
        """Transform the given data into an (n_samples, 1) array."""
        if isinstance(X, pd.DataFrame):
            x = self.calculate_difference(
                X[self.column_1].to_numpy(), X[self.column_2].to_numpy()
            )
        else:
            # ndarray input: columns are expected in the order [column_1, column_2]
            X = np.asarray(X)
            x = self.calculate_difference(X[:, 0], X[:, 1])
        return x.reshape((-1, 1))

    def get_feature_names_out(self, input_features=None) -> np.ndarray:
        """Return feature names. Required for onnx conversion."""
        return np.array([f"{self.column_1}-{self.column_2}"], dtype=object)

You can find the list of transformers and models supported by sklearn-onnx at https://onnx.ai/sklearn-onnx/supported.html.

Next, we use our custom transformers to write the preprocessing step for our model:

from sklearn.compose import ColumnTransformer

preprocesser = ColumnTransformer(
    transformers=[
        ("F1", Clipper(0.0, 17.0, "F1"), ["F1"]),
        ("F2-F3", DifferenceCalculator("F2", "F3"), ["F2", "F3"]),
        ("F4", "passthrough", ["F4"]),
    ],
    verbose_feature_names_out=False,
)
# preprocesser.set_output(transform="pandas")

I added the last line and commented it out on purpose: the ONNX conversion in Step 6 fails if the preprocessor outputs a pandas DataFrame. By default, ColumnTransformer outputs a NumPy array, which is what we want here.

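Because the transformed output is a NumPy array without column names, it is good practice to list every expected input feature in the preprocessor explicitly, one by one, so that you can keep track of the order of the transformed features. Note that ColumnTransformer drops any column you do not list, because its remainder parameter defaults to "drop".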

Step 3: Train your sklearn model

We will use an XGBoost classifier in this step. You can check the full list of sklearn models that sklearn-onnx can convert at https://onnx.ai/sklearn-onnx/supported.html.

from sklearn.pipeline import Pipeline
from xgboost import XGBClassifier

# hyperparams is your dict of XGBClassifier settings (not shown in this post)
pipe = Pipeline([
    ("preprocess", preprocesser),
    ("model", XGBClassifier(**hyperparams)),
])

pipe.fit(X_train, y_train)  # X_train is your input features DataFrame and y_train is the target.

Step 4: Write onnx shape calculator and conversion functions for custom transformers

Next, we write a shape calculator and a converter function for each custom transformer. The shape calculator tells skl2onnx the type and shape of the transformer's output; here, each transformer returns a single column with the same batch size and element type as its input.

from skl2onnx.common.utils import check_input_and_output_numbers


def difference_transformer_shape_calculator(operator):  # type: ignore
    """Calculate output shape for difference transformer."""
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    # Gets the input type, the transformer works on any numerical type.
    input_type = operator.inputs[0].type.__class__
    # The first dimension is usually dynamic (batch dimension).
    input_dim = operator.inputs[0].get_first_dimension()
    operator.outputs[0].type = input_type([input_dim, 1])


def clipper_transformer_shape_calculator(operator):  # type: ignore
    """Calculate output shape for clip transformer."""
    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=1)
    # Gets the input type, the transformer works on any numerical type.
    input_type = operator.inputs[0].type.__class__
    # The first dimension is usually dynamic (batch dimension).
    input_dim = operator.inputs[0].get_first_dimension()
    operator.outputs[0].type = input_type([input_dim, 1])

Next, we write the ONNX converters for the custom transformers. Each converter re-expresses the transformer's operation in terms of ONNX operators, which are then added to the model graph.

import numpy as np
from skl2onnx.algebra.onnx_ops import OnnxClip, OnnxSlice, OnnxSub


def difference_transformer_converter(scope, operator, container):  # type: ignore
    """Convert difference sklearn custom transformer to onnx."""
    opv = container.target_opset
    X = operator.inputs[0]
    # Index constants used to slice columns out of the input tensor (X)
    zero = np.array([0], dtype=np.int64)
    one = np.array([1], dtype=np.int64)
    two = np.array([2], dtype=np.int64)
    # Get the 2 input columns from X
    # Slice(data, starts, ends, axes): axes=[1] slices along the column axis
    x0 = OnnxSlice(X, zero, one, one, op_version=opv)
    x1 = OnnxSlice(X, one, two, one, op_version=opv)
    # Express the difference (column_1 - column_2) with the ONNX Sub operator
    z = OnnxSub(x0, x1, op_version=opv, output_names=operator.outputs[:1])
    z.add_to(scope, container)


def clipper_transformer_converter(scope, operator, container):  # type: ignore
    """Convert clipper sklearn custom transformer to onnx."""
    opv = container.target_opset
    X = operator.inputs[0]
    zero = np.array([0], dtype=np.int64)
    one = np.array([1], dtype=np.int64)
    # Read the bounds from the fitted sklearn transformer
    lower_limit = np.array([operator.raw_operator.lower_limit], dtype=np.float32)
    upper_limit = np.array([operator.raw_operator.upper_limit], dtype=np.float32)
    # Slice(data, starts, ends, axes): take column 0
    x0 = OnnxSlice(X, zero, one, one, op_version=opv)
    z = OnnxClip(
        x0,
        lower_limit,
        upper_limit,
        op_version=opv,
        output_names=operator.outputs[:1],
    )
    z.add_to(scope, container)

You can find the full list of ONNX operators at https://onnx.ai/onnx/operators/.

Step 5: Register shape calculators and converters

Next, we register the shape calculators and converter functions with skl2onnx, so that convert_sklearn can find them when it encounters these custom classes during conversion.

from skl2onnx import convert_sklearn, to_onnx, update_registered_converter

# For DifferenceCalculator
update_registered_converter(
    DifferenceCalculator,
    "DifferenceCalculator",
    difference_transformer_shape_calculator,
    difference_transformer_converter,
)

# For Clipper
update_registered_converter(
    Clipper,
    "Clipper",
    clipper_transformer_shape_calculator,
    clipper_transformer_converter,
)

Step 6: Convert to ONNX

The code below maps each input DataFrame column to an ONNX-compatible tensor type and then performs the actual ONNX conversion.

from skl2onnx import convert_sklearn
from skl2onnx.common.data_types import (
    FloatTensorType,
    Int32TensorType,
    Int64TensorType,
    StringTensorType,
)


def convert_dataframe_schema(df, drop=None):
    """Get ONNX-compatible input types for each column of the input DataFrame."""
    inputs = []
    for k, v in zip(df.columns, df.dtypes):
        if drop is not None and k in drop:
            continue
        if v == "int64":
            t = Int64TensorType([None, 1])
        elif v == "int32":
            t = Int32TensorType([None, 1])
        elif v in ("float32", "float64"):
            # float64 columns are declared as float32; cast them at inference time
            t = FloatTensorType([None, 1])
        elif v == "object":
            t = StringTensorType([None, 1])
        else:
            t = Int64TensorType([None, 1])
        inputs.append((k, t))
    return inputs


initial_inputs = convert_dataframe_schema(X_train)

# convert to onnx
model_onnx = convert_sklearn(
    pipe, "pipeline_xgboost", initial_inputs, target_opset=18
)

# export file
with open("model_xgboost.onnx", "wb") as f:
    f.write(model_onnx.SerializeToString())
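Mapping float64 columns to FloatTensorType means the ONNX model sees float32 values, so small numerical differences against the float64 sklearn pipeline are expected in Step 7. The sketch below (plain NumPy, illustrative value) shows the size of that rounding for a single number: it stays within float32's relative precision.

```python
import numpy as np

value = np.float64(0.1)
as_float32 = np.float32(value)  # what the ONNX model receives for this input

# Absolute rounding error introduced by the float64 -> float32 cast
error = abs(float(as_float32) - float(value))
print(error)
# The error is bounded by |value| * float32 machine epsilon
print(error <= abs(float(value)) * float(np.finfo(np.float32).eps))
```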

Think of the target opset as the version of the ONNX operator set that the converted model is allowed to use. Each onnxruntime release supports opsets up to a certain version, so choose a target opset that the runtime you deploy with supports. For more information about the target_opset parameter, please refer to http://onnx.ai/sklearn-onnx/auto_tutorial/plot_cbegin_opset.html

Step 7: Load ONNX Model and verify results

Next, we load the converted ONNX model and run inference with both the sklearn and ONNX models to verify that their outputs match. This is an essential test to make sure that the ONNX conversion was successful.

from onnxruntime import InferenceSession
# read model and create session
sess = InferenceSession("model_xgboost.onnx", providers=["CPUExecutionProvider"])

Convert the input DataFrame into a Python dictionary that maps each column name to a 2-D array of shape (n_rows, 1), matching the [None, 1] input types declared in Step 6.

Example:

DataFrame:

f1 | f2 | f3 | f4
v1 | v2 | v3 | v4
v5 | v6 | v7 | v8

Dictionary:

{
    "f1": np.array([[v1], [v5]]),
    "f2": np.array([[v2], [v6]]),
    "f3": np.array([[v3], [v7]]),
    "f4": np.array([[v4], [v8]]),
}

num = 20
sample = X_train[:num]
inputs = {}
for c in sample.columns:
    if sample[c].dtype == np.float64:
        # float64 columns were declared as FloatTensorType, so feed float32
        inputs[c] = sample[c].values.reshape((-1, 1)).astype(np.float32)
    else:
        inputs[c] = sample[c].values.reshape((-1, 1))

# infer
onnx_output = sess.run(None, inputs)
sklearn_output = pipe.predict_proba(sample)

# verify: onnx_output[1] holds the class probabilities; keep P(class 1)
onnx_model_preds = []
for output in onnx_output[1]:
    onnx_model_preds.append(output[1])

# compare with P(class 1) from sklearn, i.e. column 1 of predict_proba
diffs = np.asarray(onnx_model_preds) - sklearn_output[:, 1]

# mean absolute difference between onnx and pkl scores (should be as small as possible)
np.abs(diffs).mean()
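Averaging signed differences can let positive and negative errors cancel out, and the probability output can come back in two shapes: skl2onnx classifier converters add a ZipMap by default, so onnx_output[1] is a list of {label: probability} dicts, while with zipmap disabled it is a plain 2-D array. The helpers below are a hypothetical sketch that handles both and compares element-wise with an absolute tolerance; the 1e-5 tolerance is an assumption to tune for your model.

```python
import numpy as np


def positive_class_proba(onnx_probs, positive_label=1):
    """Return P(positive_label) from an onnxruntime probability output.

    Accepts a ZipMap output (a list of {label: prob} dicts) or a 2-D array of
    shape (n_samples, n_classes); for the array case, positive_label is used as
    the column index, which assumes labels 0..n_classes-1.
    """
    if len(onnx_probs) > 0 and isinstance(onnx_probs[0], dict):
        return np.array([row[positive_label] for row in onnx_probs], dtype=np.float64)
    return np.asarray(onnx_probs, dtype=np.float64)[:, positive_label]


def assert_outputs_match(onnx_probs, sklearn_probs, atol=1e-5):
    """Raise AssertionError if any P(class 1) differs by more than atol."""
    onnx_p1 = positive_class_proba(onnx_probs)
    sklearn_p1 = np.asarray(sklearn_probs, dtype=np.float64)[:, 1]
    np.testing.assert_allclose(onnx_p1, sklearn_p1, rtol=0.0, atol=atol)
    # Report the worst-case gap for logging
    return float(np.max(np.abs(onnx_p1 - sklearn_p1)))


# Usage with made-up values shaped like a ZipMap output and predict_proba
max_abs = assert_outputs_match(
    [{0: 0.3, 1: 0.7}, {0: 0.9, 1: 0.1}],
    [[0.3, 0.7], [0.9, 0.1]],
)
print(max_abs)
```

With the variables from the verification script, the call would be assert_outputs_match(onnx_output[1], sklearn_output).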

Summary

To conclude, I would like to reiterate that converting ML models to ONNX can feel daunting at first, but you gradually get acquainted with the process and it gets easier with practice. In the machine learning team at Turo, adopting the ONNX framework was a collaborative effort. We started small by converting a single model, then built on what we learned and shared that knowledge, which led to a successful transition and organization-wide adoption.

More Resources

https://github.com/onnx/sklearn-onnx

https://onnx.ai/sklearn-onnx/pipeline.html
