ONNX inference speed: To ensemble or not?

Pascal Niville
Ixor
Sep 26, 2022

At Ixor we have several machine learning models in production, running on AWS. The “time is money” statement can be taken literally in cloud computing, so we are always looking for ways to speed up our production code. One option is to profile our code and make everything more efficient; another is to optimize the size of our machine learning models.

Two deep learning principles that generally hold are:

  • Larger models score better than smaller models
  • Ensembles score better than single models

Given the above, we wanted to know if an ensemble of small models could outperform a larger single model and be faster at the same time.

DALL-E-assisted impression of the ensemble vs. large model hypothesis

Experiment setup

The model used for this experiment is a UNet with a classification head. Its size is controlled by a “width scale” variable, which is essentially a proxy for the number of channels per layer.

For inference, we export our models to ONNX, a specialized format for running machine learning models independently of the framework they were trained with (such as PyTorch). ONNX Runtime is generally (much) faster than PyTorch because it executes a static ONNX graph, which allows optimizations that would be hard or impossible in PyTorch. In a sense, the difference resembles that between compiled and interpreted implementations of a programming language. [1]

The reported inference duration is the average inference time per sample, in seconds, over a test set of 500 samples.
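Such an average could be measured roughly as in the minimal sketch below. It assumes a hypothetical `run_model` callable that performs one inference on one sample. The warm-up run and the use of `time.perf_counter` are our own choices, not necessarily the exact benchmark behind the numbers in this post.

```python
import time
from typing import Any, Callable, Sequence


def average_inference_time(run_model: Callable[[Any], Any],
                           samples: Sequence[Any],
                           warmup: int = 1) -> float:
    """Return the mean wall-clock time in seconds per sample.

    `run_model` is a hypothetical callable that runs one inference on one sample.
    """
    if not samples:
        raise ValueError("samples must not be empty")
    # Warm-up calls are excluded from the measurement (first calls are often slower)
    for sample in samples[:warmup]:
        run_model(sample)
    start = time.perf_counter()
    for sample in samples:
        run_model(sample)
    elapsed = time.perf_counter() - start
    # Average over all measured samples
    return elapsed / len(samples)
```

With an ONNX Runtime session this could be called as `average_inference_time(lambda x: session.run(None, {'input': x}), test_samples)`, where `session` and `test_samples` are placeholders.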

The number of parameters in the network as a function of the width scale parameter

The code we use to run our models in production:

```python
import onnxruntime as ort

# Initialisation phase (excerpt from our inference class): load the models into memory once
self.models = [ort.InferenceSession(p) for p in self.model_paths]

# Inference phase: run the models one by one on the same input
model_outputs = []
for model in self.models:
    model_outputs.append(model.run(None, {'input': input_tensor}))
```
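The loop above produces one output per model, but the post does not say how those outputs are combined. A common choice, which is an assumption here and not necessarily what we do in production, is soft voting: averaging the per-class probabilities of the ensemble members. The following is a minimal standard-library sketch, assuming each member's output has already been converted to a plain list of class probabilities:

```python
from typing import List, Sequence


def soft_vote(member_probs: Sequence[Sequence[float]]) -> List[float]:
    """Average per-class probabilities from several ensemble members (hypothetical helper)."""
    if not member_probs:
        raise ValueError("need at least one member output")
    n_classes = len(member_probs[0])
    if any(len(p) != n_classes for p in member_probs):
        raise ValueError("all members must predict the same number of classes")
    # Element-wise mean over the members, one value per class
    return [sum(p[c] for p in member_probs) / len(member_probs)
            for c in range(n_classes)]
```

For example, `soft_vote([[1.0, 0.0], [0.5, 0.5]])` returns `[0.75, 0.25]`.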

Results

The image below shows the results of the experiment. The processing time (s) is on the x-axis and the average label accuracy is on the y-axis. The color corresponds to the overall size of the model or ensemble: a single model of width scale 7 gets the value 7, while an ensemble of a size-2 and a size-3 model is labeled 23.
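The labeling scheme amounts to concatenating the width scales of the members. This tiny helper spells it out. The name is hypothetical, and it only accepts single-digit width scales because larger ones would make labels ambiguous (for example, 12 could mean one model of size 12 or sizes 1 and 2).

```python
from typing import Sequence


def ensemble_label(width_scales: Sequence[int]) -> int:
    """Concatenate width scales into the plot label: [7] -> 7, [2, 3] -> 23."""
    # Restrict to single digits so every label maps back to exactly one configuration
    if not width_scales or any(not 1 <= w <= 9 for w in width_scales):
        raise ValueError("expects one or more single-digit width scales")
    return int("".join(str(w) for w in width_scales))
```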

The inference time of a single model scales approximately linearly with its size, and the same holds for the ensembles. Surprisingly, though, running two separate ONNX models carries “a lot” of overhead: the largest and best-scoring single model (size 7, 9 million parameters) is faster than the smallest and worst-scoring ensemble (two models of size 2, 5 million parameters in total)!
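One way to read this result is through a simple linear cost model. This is our assumption, not something measured separately in the experiment. Each ONNX session run costs a fixed overhead $c$ plus a term proportional to its parameter count $p_i$ (in millions), with slope $k$:

```latex
t_{\text{single}}(p) \approx c + k\,p
\qquad
t_{\text{ensemble}} \approx \sum_{i=1}^{n} \left(c + k\,p_i\right) = n\,c + k \sum_{i=1}^{n} p_i

% Observed: the single model (p = 9) is faster than the two-model ensemble (\sum_i p_i = 5, n = 2)
c + 9k < 2c + 5k \quad\Longleftrightarrow\quad c > 4k
```

Under this model, the fixed cost of an extra session run would have to exceed the compute cost of about 4 million parameters. That matches the observation that the per-model overhead, rather than the parameter count, dominates for small ensembles.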

Conclusion

If you run your models with ONNX and inference time matters, it can be worth dropping the ensemble and going for one big model instead.

Update 30/09/2022

This conclusion is only partially true and depends on the implementation. In hindsight, the implementation described above isn't optimal for ensembling with ONNX. Instead, you should wrap the smaller models in one large ensemble model before exporting it to ONNX. This can be done as follows:

```python
import torch
import torch.nn as nn


class EnsembleModel(nn.Module):
    def __init__(self, modelpaths: list):
        super().__init__()
        # Loading the models in a for loop gave an error when converting to ONNX,
        # so each model gets its own attribute
        self.model1 = torch.load(modelpaths[0], map_location="cpu").eval()
        self.model2 = torch.load(modelpaths[1], map_location="cpu").eval()

    def forward(self, x):
        # Stack the first output of each model along a new dimension 1
        return torch.stack([self.model1(x)[0], self.model2(x)[0]], dim=1).squeeze(0)
```

Running inference with this approach is about twice as fast as the old implementation:

Average inference time:
  • New EnsembleModel: 80.47 ms
  • Old implementation: 172.27 ms
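The “twice as fast” figure follows directly from the two averages above:

```python
# Average inference times reported above, in milliseconds
new_ensemble_ms = 80.47
old_implementation_ms = 172.27

# Ratio of old to new time = speedup of the wrapped EnsembleModel
speedup = old_implementation_ms / new_ensemble_ms
print(f"speedup: {speedup:.2f}x")  # prints: speedup: 2.14x
```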

Good news for ensembling!

At IxorThink, the machine learning practice of Ixor, we are constantly trying to improve our methods to create state-of-the-art solutions. As a software company, we can provide stable products from proof-of-concept to deployment. Feel free to contact us for more information.

Sources

[1] https://stackoverflow.com/questions/67943173/onnxruntime-vs-pytorch
