
Speed Up PyTorch with Custom Kernels

We’ll begin with torch.compile, move on to writing a custom Triton kernel, and finally dive into designing a CUDA kernel

Alex Dremov
Published in TDS Archive
5 min read · Jan 9, 2025


Read for free at alexdremov.me

PyTorch offers remarkable flexibility, allowing you to code complex GPU-accelerated operations in a matter of seconds. However, this convenience comes at a cost: PyTorch executes your operations one at a time, resulting in suboptimal performance. This translates into slower model training, which slows the iteration cycle of your experiments, hurts your team's productivity, drives up compute costs, and so on.

In this post, I'll explore three strategies for accelerating your PyTorch operations. Each method uses softmax as its "Hello World" demonstration, but you can swap in any function you like, and the discussed methods will still apply.

We’ll begin with torch.compile, move on to writing a custom Triton kernel, and finally dive into designing a CUDA kernel.

So, this post may get complicated, but bear with me.

torch.compile — A Quick Way to Boost Performance

💥 “Wait, you just turn on a single function call and it speeds up your code? That’s it? Sounds too good to be true.”

— Yes.

torch.compile is a relatively new API in PyTorch that uses runtime graph capture and kernel fusion under the hood. With one decorator, you can often see speed improvements without significant changes to your code.

Simply put, it can speed up calculations by merging operations into a single GPU kernel, which removes the overhead of separate kernel launches. Better yet, it can replace a whole chain of operations with a single equivalent one!

Such optimizations are not possible in PyTorch's regular execution mode (eager), which runs operations exactly as they are called in the code.
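For instance, here's a minimal sketch of what fusion buys you: a short chain of elementwise ops that eager mode dispatches as separate kernels, but that torch.compile can fuse into one generated kernel (the function name and tensor shapes here are purely illustrative).

import torch

def scale_shift_relu(x):
    # Three elementwise ops: in eager mode each launches its own GPU kernel,
    # while torch.compile can fuse the whole chain into a single kernel.
    return torch.relu(x * 2.0 + 1.0)

compiled = torch.compile(scale_shift_relu)

x = torch.randn(4096, 4096, device="cuda")
y = compiled(x)  # first call triggers compilation; later calls reuse the kernel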

Softmax Implementation with torch.compile

Below is a simple example showing how to implement and compile a softmax function using torch.compile. Drop it into your model's forward pass, and your code will (hopefully) run faster.
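Here's a minimal sketch, assuming a numerically stable softmax over the last dimension (the tensor shapes and the sanity check are illustrative):

import torch

@torch.compile
def softmax(x: torch.Tensor) -> torch.Tensor:
    # Numerically stable softmax over the last dimension:
    # subtracting the row max before exp avoids overflow.
    x_max = x.max(dim=-1, keepdim=True).values
    exp_x = torch.exp(x - x_max)
    return exp_x / exp_x.sum(dim=-1, keepdim=True)

x = torch.randn(1024, 1024, device="cuda")
out = softmax(x)  # first call compiles; later calls reuse the generated kernel

# Sanity check against PyTorch's built-in softmax
torch.testing.assert_close(out, torch.nn.functional.softmax(x, dim=-1))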

