A Trip to Kernels: Understanding PyTorch’s Internal Architecture

Huan Xu
Jul 8, 2023


If you’re here, you probably know that PyTorch is one of the most popular libraries among deep learning practitioners. It is highly efficient, and its flexibility makes it easy to write custom code when the built-in operators lack functionality your project needs.

However, working with PyTorch’s codebase can be intimidating, especially if you’re used to working solely with high-level abstractions. In this blog post, we will dive into PyTorch’s internal architecture. Specifically, we’ll explore two critical components that facilitate efficient computations: Tensor Wrappers and Kernels.

We will cover:

  • Tensor Wrapper (e.g. batch tensors)
  • Anatomy of an operator call
  • Writing kernels
  • On more efficient development workflow
  • Additional resources

Let’s begin 🚀

Tensor Wrapper (e.g. batch tensors)

A PyTorch Tensor carries well-designed metadata that decouples its logical layout from its physical memory and provides the primitives for dynamic operator dispatch from the Python frontend to the high-performance C++ backend. Every tensor carries three pieces of metadata: device, layout, and dtype.

  • device lets an operator be dispatched to a kernel specialized for the device the tensor lives on (e.g. CPU or CUDA), for better efficiency.
  • layout lets multiple views share the same hardware-level memory allocation, decoupling the logical tensor layout from the physical memory layout. When the physical memory must actually be accessed, TensorAccessor uses the layout (e.g. strides) to locate tensor elements. You can learn more about TensorAccessor from this blog or this awesome PyTorch internals podcast.
  • dtype lets an operator be dispatched to a kernel specialized for the element type, for better efficiency, via the AT_DISPATCH_ALL_TYPES macro in ATen kernels.
Tensor layout decouples logical tensor layout from physical memory layout
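To make the layout idea concrete, here is a toy model in plain Python, not PyTorch’s actual implementation. The class `StridedView` and its methods are hypothetical: a view is a flat storage list plus sizes, strides, and a storage offset, and the element at a logical index lives at offset + sum(index[k] * stride[k]). Transposing or selecting a row only rewrites this metadata, so every view keeps reading the same storage.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class StridedView:
    """Toy strided view (hypothetical; not PyTorch's implementation)."""
    storage: List[float]  # flat "physical" memory, shared by all views
    sizes: List[int]      # logical shape
    strides: List[int]    # storage elements to skip per step along each dim
    offset: int = 0       # storage index of the element at (0, 0, ...)

    def physical_index(self, index: Sequence[int]) -> int:
        # Logical multi-index -> position in the flat storage.
        if len(index) != len(self.sizes):
            raise IndexError("wrong number of indices")
        pos = self.offset
        for i, size, stride in zip(index, self.sizes, self.strides):
            if not 0 <= i < size:
                raise IndexError("index out of range")
            pos += i * stride
        return pos

    def __getitem__(self, index: Sequence[int]) -> float:
        return self.storage[self.physical_index(index)]

    def transpose(self) -> "StridedView":
        # 2-D transpose: swap sizes and strides; storage is shared, not copied.
        if len(self.sizes) != 2:
            raise ValueError("transpose() here only handles 2-D views")
        return StridedView(self.storage, self.sizes[::-1], self.strides[::-1], self.offset)

    def select_row(self, row: int) -> "StridedView":
        # Drop the first dimension and move the offset to the chosen row.
        return StridedView(self.storage, self.sizes[1:], self.strides[1:],
                           self.offset + row * self.strides[0])


# A contiguous 2x3 matrix [[0, 1, 2], [3, 4, 5]] stored row-major.
storage = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
m = StridedView(storage, sizes=[2, 3], strides=[3, 1])
t = m.transpose()       # shape [3, 2], strides [1, 3]
row1 = m.select_row(1)  # shape [3], offset 3

print(m[(1, 2)], t[(2, 1)], row1[(0,)])  # 5.0 5.0 3.0
storage[5] = 50.0  # a write to the shared storage is seen by every view
print(m[(1, 2)], t[(2, 1)])              # 50.0 50.0
```

In real PyTorch the same metadata is visible through `Tensor.stride()` and `Tensor.storage_offset()`; for example, `x.t()` returns a view with swapped strides rather than a copy.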

Anatomy of an operator call

When you develop a custom operator in PyTorch, it helps to know that several steps happen before any computation occurs when an operator is called: Python argument parsing, variable type switching, data type switching, and finally kernel dispatch.

📧 Note: The code that performs this dispatch is auto-generated from aten/src/ATen/native/native_functions.yaml
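To illustrate what "auto-generated" means here, the sketch below turns schema strings into Python wrapper functions that forward to a dispatcher. The schema strings are loosely modeled on the `func:` entries in native_functions.yaml, but the regex parser, `generate_wrapper`, and the stand-in `dispatch` are hypothetical and far simpler than PyTorch’s real code generator, which emits C++ source.

```python
import re
from typing import Dict, List

# Hypothetical schemas, loosely modeled on native_functions.yaml `func:` entries.
SCHEMAS: List[str] = [
    "add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
    "mul(Tensor self, Tensor other) -> Tensor",
]

SCHEMA_RE = re.compile(r"^(\w+)\((.*)\)\s*->\s*(\w+)$")


def generate_wrapper(schema: str) -> str:
    """Emit Python source for a wrapper that forwards its arguments to dispatch()."""
    match = SCHEMA_RE.match(schema)
    if match is None:
        raise ValueError(f"cannot parse schema: {schema}")
    name, arg_list, _ret = match.groups()
    params = []
    for arg in (a.strip() for a in arg_list.split(",")):
        if arg == "*":
            params.append("*")           # keyword-only marker carries over
            continue
        _type, decl = arg.split(" ", 1)  # e.g. "Scalar", "alpha=1"
        params.append(decl)
    names = [p.split("=")[0] for p in params if p != "*"]
    call_args = ", ".join(f"{n}={n}" for n in names)
    return (f"def {name}({', '.join(params)}):\n"
            f"    return dispatch({name!r}, {call_args})\n")


def dispatch(op: str, **kwargs):
    # Stand-in for the real dispatcher: just report what would be called.
    return (op, kwargs)


# "Build step": generate and load every wrapper into a namespace.
namespace: Dict[str, object] = {"dispatch": dispatch}
for schema in SCHEMAS:
    exec(generate_wrapper(schema), namespace)

print(generate_wrapper(SCHEMAS[0]))
print(namespace["add"](1, 2))  # ('add', {'self': 1, 'other': 2, 'alpha': 1})
```

The point is that nobody hand-writes the glue between the Python call and the dispatcher; it is derived mechanically from the schema.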

Step 1: Python argument parsing

  • You are still in Python land 🥳
  • The purpose of this step is to pass the Python arguments to the C++ bindings in torch/csrc
  • For example: torch.add → THPVariable_add (in torch._C._VariableFunctions, auto-generated)
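The argument-parsing step can be pictured with the standard library’s `inspect.Signature`: bind the caller’s positional and keyword arguments to a fixed signature, fill in defaults, and reject bad calls before any backend code runs. This is an analogy, not how torch/csrc works internally. The signature mirrors the documented `torch.add(input, other, *, alpha=1)` without `out`, and `fake_cpp_add` is a hypothetical stand-in for the C++ binding that operates on plain lists.

```python
import inspect
from typing import Any, Dict

# Signature mirroring torch.add(input, other, *, alpha=1), minus `out`.
ADD_SIGNATURE = inspect.Signature([
    inspect.Parameter("input", inspect.Parameter.POSITIONAL_OR_KEYWORD),
    inspect.Parameter("other", inspect.Parameter.POSITIONAL_OR_KEYWORD),
    inspect.Parameter("alpha", inspect.Parameter.KEYWORD_ONLY, default=1),
])


def parse_args(signature: inspect.Signature, args, kwargs) -> Dict[str, Any]:
    """Bind Python arguments to a fixed signature and fill in defaults.

    Raises TypeError on a malformed call, before any "C++" code runs.
    """
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()
    return dict(bound.arguments)


def fake_cpp_add(input, other, alpha):
    # Hypothetical stand-in for the C++ binding; works on plain lists.
    return [a + alpha * b for a, b in zip(input, other)]


def add(*args, **kwargs):
    # Python-facing entry point: parse first, then call into the "backend".
    parsed = parse_args(ADD_SIGNATURE, args, kwargs)
    return fake_cpp_add(**parsed)


print(add([1, 2], [10, 20]))           # [11, 22]
print(add([1, 2], [10, 20], alpha=2))  # [21, 42]
try:
    add([1, 2], [10, 20], 2)           # alpha is keyword-only, so this fails
except TypeError as err:
    print("rejected:", err)
```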

Step 2: Variable type switch

  • You are now in serious ATen C++ land 🥸
  • The purpose of this step is to redirect the op to the function corresponding to the tensor’s device
  • For example: THPVariable_add → VariableDefault::add (in aten/src/ATen/TypeDefault.cpp, auto-generated)
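A toy version of the device switch: a hypothetical registry maps each (operator, device) pair to an implementation, and the dispatcher picks one using only the tensors’ device metadata, never the data. `FakeTensor`, `register`, and `dispatch` are illustrative names, not PyTorch APIs; the real dispatcher lives in C++ and keys on more than just the device.

```python
from typing import Callable, Dict, List


class FakeTensor:
    """Hypothetical stand-in for a tensor: data plus device metadata."""

    def __init__(self, data: List, device: str = "cpu"):
        self.data = data
        self.device = device


# Hypothetical registry: operator name -> {device -> implementation}.
KERNELS: Dict[str, Dict[str, Callable]] = {}


def register(op: str, device: str):
    def decorator(fn: Callable) -> Callable:
        KERNELS.setdefault(op, {})[device] = fn
        return fn
    return decorator


@register("add", "cpu")
def add_cpu(a: FakeTensor, b: FakeTensor) -> FakeTensor:
    return FakeTensor([x + y for x, y in zip(a.data, b.data)], "cpu")


def dispatch(op: str, *tensors: FakeTensor) -> FakeTensor:
    # Route the call using the device metadata only.
    devices = {t.device for t in tensors}
    if len(devices) != 1:
        raise RuntimeError(f"expected all tensors on one device, got {sorted(devices)}")
    device = devices.pop()
    try:
        kernel = KERNELS[op][device]
    except KeyError:
        raise NotImplementedError(f"{op} has no kernel for device {device!r}") from None
    return kernel(*tensors)


out = dispatch("add", FakeTensor([1, 2]), FakeTensor([3, 4]))
print(out.data, out.device)  # [4, 6] cpu
```

Registering a CUDA implementation would just add another entry under `KERNELS["add"]`; the calling code would not change.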

Step 3: Data type switch

  • You are now in kernel land 🚀. Here things are device-specific!
  • The purpose of this step is to redirect the op to the function corresponding to the tensor’s dtype
  • By this step, we have reached the kernel proper. It may live on the better side of town (native, in C++) or the worse side of town (TH, in C).
    • For example: `VariableDefault::add` → `at::native::add` (in `aten/src/ATen/native/BinaryOps.cpp`)
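The dtype switch follows the same idea one level down. In C++, AT_DISPATCH_ALL_TYPES switches on the runtime dtype and runs a templated lambda in which `scalar_t` names the concrete element type. The Python sketch below only mimics that shape: the hypothetical `dispatch_all_types` looks up a Python type for the dtype name and hands it to the kernel body as `scalar_t`. The dtype table is a small made-up subset, not PyTorch’s list.

```python
from typing import Callable, List, TypeVar

R = TypeVar("R")

# Hypothetical subset: dtype name -> Python type playing the role of scalar_t.
SCALAR_TYPES = {"float32": float, "float64": float, "int32": int, "int64": int}


def dispatch_all_types(dtype: str, op_name: str, body: Callable[[type], R]) -> R:
    """Toy analogue of the AT_DISPATCH_ALL_TYPES pattern.

    Selects the concrete element type for the runtime dtype and runs the
    kernel body with it; unsupported dtypes raise TypeError.
    """
    try:
        scalar_t = SCALAR_TYPES[dtype]
    except KeyError:
        raise TypeError(f"{op_name} not implemented for {dtype!r}") from None
    return body(scalar_t)


def add_kernel(a: List, b: List, dtype: str) -> List:
    def body(scalar_t: type) -> List:
        # Inside the "specialized" kernel, every element is treated as scalar_t.
        return [scalar_t(x) + scalar_t(y) for x, y in zip(a, b)]
    return dispatch_all_types(dtype, "add", body)


print(add_kernel([1, 2], [3, 4], "int64"))    # [4, 6]
print(add_kernel([1, 2], [3, 4], "float32"))  # [4.0, 6.0]
```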

It’s worth re-emphasizing that all of the code up to the kernel is automatically generated. It is twisty, so once you have a basic sense of what’s going on, we recommend jumping straight to the kernels.

Don’t get lost in details

Writing kernels

There are several steps involved when writing kernels. Here’s a brief overview:

  1. Begin with error checking (e.g. make sure the input tensors have the correct dimensions).
  2. Next, we generally have to allocate the result tensor that the output will be written into, using `result.resize_(self.sizes())`. The trailing _ in resize_ marks an in-place modification.
  3. Now do the second dispatch, on dtype, to jump into a kernel specialized for the dtype it operates on. Use the AT_DISPATCH_ALL_TYPES macro mentioned earlier. (Some devices do not require this step.)
  4. Most performant kernels need some form of parallelization. The implementation is device-specific.
  5. Finally, access the data and do the computation you wanted to do:
    • If you just want to get a value at some specific location, use TensorAccessor.
    • If you’re writing an operator with very regular element access (e.g. an elementwise op), use TensorIterator. You can learn more about TensorIterator here.
    • For true speed on CPU, use vectorized helpers like binary_kernel_vec.
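The five steps can be traced in a toy elementwise kernel written in plain Python. `add_out` is hypothetical and only mirrors the structure: it checks sizes, resizes the output in place, skips dtype dispatch because it handles only floats, splits the index range across a thread pool, and writes each chunk directly. Because of CPython’s global interpreter lock, the threads here illustrate the chunking pattern rather than deliver a real speedup.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import List


def add_out(result: List[float], a: List[float], b: List[float],
            alpha: float = 1.0, num_chunks: int = 4) -> List[float]:
    """Toy kernel computing result = a + alpha * b (hypothetical, floats only)."""
    # Step 1: error checking.
    if len(a) != len(b):
        raise ValueError(f"size mismatch: {len(a)} vs {len(b)}")
    if num_chunks < 1:
        raise ValueError("num_chunks must be at least 1")

    # Step 2: resize the output in place, in the spirit of result.resize_(...).
    n = len(a)
    result[:] = [0.0] * n

    # Step 3: no dtype dispatch; this toy assumes every element is a float.

    # Step 4: parallelize by giving each worker a contiguous, disjoint chunk.
    chunk = max(1, -(-n // num_chunks))  # ceiling division

    def work(start: int) -> None:
        # Step 5: access the data and compute.
        for i in range(start, min(start + chunk, n)):
            result[i] = a[i] + alpha * b[i]

    with ThreadPoolExecutor(max_workers=num_chunks) as pool:
        list(pool.map(work, range(0, n, chunk)))  # list() re-raises worker errors
    return result


out: List[float] = []
add_out(out, [1.0, 2.0, 3.0], [10.0, 20.0, 30.0], alpha=0.5)
print(out)  # [6.0, 12.0, 18.0]
```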

On more efficient development workflow

When writing code, there’s nothing more frustrating than changing a header file and then waiting hours for the rebuild to finish. Here are a few tips for making your PyTorch development faster and more efficient:

  • Don’t edit headers: If you edit a header, especially one that is included by many source files (and especially one included by CUDA files), expect a very long rebuild. Try to stick to editing .cpp files, and edit headers sparingly!
  • Don’t test by CI: CI is wonderful, but expect to wait an hour or two before you get a signal back. If you are working on a change that will require lots of experimentation, spend the time setting up a local development environment.
  • Do set up a cache: If you are working in C++ land, setting up ccache can save you a lot of build time. The CONTRIBUTING guide explains how to set up ccache.

Conclusion

PyTorch’s internal architecture is complex but highly efficient. Understanding Tensor Wrappers and Kernels helps you reason about the efficiency of your computations and write custom code that suits your project’s needs. We hope this overview of PyTorch internals helps you dive deeper into PyTorch development. PyTorch continues to evolve, so be sure to stay up to date with additional resources and development guides.

Additional Resources:


MSCS@GaTech. Interested in accessible ML inference. Building www.baynana.co, an AI-powered resume supercharger.