A Trip to Kernels: Understanding PyTorch’s Internal Architecture
If you’re here, you know that PyTorch is one of the most popular libraries among deep learning practitioners. It is highly efficient, and its flexibility makes it easy to write custom code when existing operators lack functionality your project requires.
However, working with PyTorch’s codebase can be intimidating, especially if you’re used to working solely with high-level abstractions. In this blog post, we will dive into PyTorch’s internal architecture. Specifically, we’ll explore two critical components that facilitate efficient computations: Tensor Wrappers and Kernels.
We will cover:
- Tensor Wrapper (e.g. batch tensors)
- Anatomy of an operator call
- Writing kernels
- On more efficient development workflow
- Additional resources
Let’s begin 🚀
Tensor Wrapper (e.g. batch tensors)
PyTorch’s Tensor has well-designed metadata that decouples the logical memory layout from physical memory and provides primitives for dynamic operator dispatch from the Python frontend to the high-performance C++ backend. Every tensor carries `device`, `layout`, and `dtype` metadata.
- `device` allows dispatching ops to per-device kernels for better efficiency.
- `layout` allows multiple views to share the same hardware-level memory allocation, decoupling the logical tensor layout from the physical memory layout. When access to the actual physical memory is required, `TensorAccessor` leverages the layout (e.g. strides) to access tensor elements. You can learn more about `TensorAccessor` from this blog or this awesome PyTorch internals podcast.
- `dtype` allows dispatching ops to per-dtype kernels for better efficiency via the `AT_DISPATCH_ALL_TYPES` macro in `aten` kernels.
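To make the `layout` idea concrete, here is a minimal pure-Python sketch (no PyTorch involved; `offset` is a hypothetical helper, not real ATen code) of how sizes and strides map a logical index onto flat physical storage, in the spirit of what `TensorAccessor` does:

```python
# A minimal sketch (plain Python, no torch) of stride-based layout:
# physical offset = sum of index[d] * strides[d] over dimensions.

def offset(index, strides):
    """Map a logical index to an offset into flat storage."""
    return sum(i * s for i, s in zip(index, strides))

# A contiguous 2x3 "tensor" stored in one flat buffer.
storage = [0, 1, 2, 3, 4, 5]
sizes, strides = (2, 3), (3, 1)      # row-major layout

# Logical element [1][2] lives at offset 1*3 + 2*1 = 5.
assert storage[offset((1, 2), strides)] == 5

# A transposed *view* reuses the same storage: only the metadata
# (sizes/strides) changes, not the physical memory allocation.
t_sizes, t_strides = (3, 2), (1, 3)
assert storage[offset((2, 1), t_strides)] == 5
```

Note how the transpose is free: no data is copied, only the stride metadata differs between the two views.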
Anatomy of an operator call
When developing a custom operator with PyTorch, multiple steps are involved before any computation occurs. This process involves Python argument parsing, variable type switching, data type switching, and eventually, kernel dispatch.
📧 Note: The code that performs the dispatch is auto-generated from `aten/src/ATen/native/native_functions.yaml`
Step 1: Python argument parsing
- You are still in Python land 🥳
- The purpose of this step is to pass Python arguments to the C++ bindings via `torch/csrc`
- For example: `torch.add` → `THPVariable_add` (at `torch._C.VariableFunctions`, auto-generated)
Step 2: Variable type switch
- You are now in serious `aten` C++ land 🥸
- The purpose of this step is to redirect ops to functions corresponding to a tensor’s `device`
- For example: `THPVariable_add` → `VariableDefault::add` (at `aten/src/ATen/TypeDefault.cpp`, auto-generated)
Step 3: Data type switch
- You are now in kernel land 🚀. Here things are device-specific!
- The purpose of this step is to redirect ops to functions corresponding to a tensor’s `dtype`
- By this step we have reached the kernel proper; it could be on the better side of town (`native`, in C++) or the worse side of town (`TH`, in C).
- For example: `VariableDefault::add` → `at::native::add` (at `aten/src/ATen/native/BinaryOps.cpp`)
It’s worth re-emphasizing that all of the code up to the kernel is automatically generated. It’s a bit twisty and turny, so once you have some basic orientation about what’s going on, it is recommended to jump straight to the kernels.
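The device-then-dtype routing described in steps 2 and 3 can be sketched as a toy two-level dispatch table (hypothetical names, plain Python; the real dispatch is auto-generated C++):

```python
# A toy sketch of two-level dispatch: first on device, then on dtype.
# The function names here are illustrative, not real ATen kernels.

def add_cpu_float(a, b):   # stand-in for a dtype-specialized kernel
    return [x + y for x, y in zip(a, b)]

def add_cpu_int(a, b):
    return [int(x + y) for x, y in zip(a, b)]

# Device-level table (step 2) routes into a dtype-level table (step 3).
DISPATCH = {
    "cpu": {"float": add_cpu_float, "int": add_cpu_int},
    # "cuda": {...}  # each device registers its own kernels
}

def add(a, b, device="cpu", dtype="float"):
    kernel = DISPATCH[device][dtype]   # two table lookups, then the kernel
    return kernel(a, b)

print(add([1.0, 2.0], [3.0, 4.0]))        # [4.0, 6.0]
print(add([1, 2], [3, 4], dtype="int"))   # [4, 6]
```

The point of the indirection is that the Python caller never names a kernel directly: the tensor’s own metadata selects it.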
Writing kernels
There are several steps involved when writing kernels. Here’s a brief overview:
1. Begin with error checking (e.g. make sure the input tensors have the correct dimensions).
2. Next, we generally have to allocate the result tensor that we are going to write the output into, using `result.resize_(self.sizes());`. The `_` in `resize_` means in-place modification.
3. [This step is not required on some devices] At this point, you should do the second, `dtype` dispatch, to jump into a kernel specialized for the `dtype` it operates on. You will use the `AT_DISPATCH_ALL_TYPES` macro, as discussed in “Anatomy of an operator call”.
4. Most performant kernels need some sort of parallelization. The implementation is device-specific.
5. Finally, you access the data and do the computation you wanted to do!
   - If you just want to read a value at a specific location, use `TensorAccessor`.
   - If you’re writing an operator with very regular element access, use `TensorIterator`. You can learn more about `TensorIterator` here.
   - For true speed on CPU, use helpers like `binary_kernel_vec`.
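The five steps above can be mirrored in a pure-Python sketch of an elementwise add kernel (a hypothetical stand-in, not real ATen code):

```python
# Hypothetical sketch mirroring the kernel-writing steps above
# for an elementwise "add"; real kernels are C++ in aten/.

def add_kernel(a, b, dtype="float"):
    # Step 1: error checking -- input shapes must match.
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} vs {len(b)}")

    # Step 2: allocate the result buffer (stand-in for resize_).
    result = [None] * len(a)

    # Step 3: dtype dispatch (stand-in for AT_DISPATCH_ALL_TYPES).
    cast = {"float": float, "int": int}[dtype]

    # Steps 4-5: the (serial) inner loop doing the actual compute;
    # a real kernel would parallelize and vectorize this part.
    for i in range(len(a)):
        result[i] = cast(a[i] + b[i])
    return result

print(add_kernel([1, 2, 3], [4, 5, 6]))  # [5.0, 7.0, 9.0]
```

Each step is deliberately separated so you can see where error checks, allocation, and dispatch sit relative to the inner loop.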
On more efficient development workflow
When writing code, there’s nothing more frustrating than having to change a header file and waiting several hours for the necessary builds to occur. Here are a few tips for maximizing the efficiency and speed of your PyTorch code development:
- Don’t edit headers: If you edit a header, especially one that is included by many source files (and especially if it is included by CUDA files), expect a very long rebuild. Try to stick to editing cpp files, and edit headers sparingly!
- Don’t test via CI: CI is wonderful, but expect to wait an hour or two before you get a signal back. If you are working on a change that will require lots of experimentation, spend the time setting up a local development environment.
- Do set up a cache: If you are working in C++ land, setting up ccache can save you a lot of build time. The CONTRIBUTING guide explains how to set up ccache.
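As a sketch of what a ccache setup can look like (a configuration fragment only; exact steps vary by platform, so follow the CONTRIBUTING guide for your system), the idea is to wrap your compilers before building:

```shell
# Configuration sketch (paths and sizes depend on your system):
# point compiler env vars at ccache wrappers before building PyTorch.
export CC="ccache gcc"
export CXX="ccache g++"
export CUDA_NVCC_EXECUTABLE="ccache nvcc"  # only if building with CUDA
ccache -M 25G   # give the cache plenty of room for a full PyTorch build
```

With this in place, a rebuild after `git clean` or a branch switch can reuse previously compiled objects instead of recompiling everything.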
Conclusion
In conclusion, PyTorch’s internal architecture is complex but highly efficient. Understanding Tensor Wrappers and Kernels enables you to write efficient computations and custom code tailored to your project’s needs. We hope this overview of PyTorch’s internals helps you dive deeper into PyTorch development. PyTorch continues to develop and evolve, so be sure to stay updated with additional resources and development guides.
Additional Resources:
- PyTorch Official Wiki: how to author a kernel
- PyTorch TensorIterator Internals
- Podcast: TensorIterator
- How to understand Pytorch Source Code?
- A Tour of PyTorch Internals: Part 1
- A Tour of PyTorch Internals: Part 2
- PyTorch — Internal Architecture Tour
- Podcast: PyTorch Developer Podcast
- PyTorch Wiki
- PyTorch Internals by its author