How Activation Checkpointing enables scaling up training deep learning models

Published in PyTorch · Nov 8, 2023

By Yiftach Beer, Omri Bar

Overview

Activation checkpointing is a technique used for reducing the memory footprint at the cost of more compute. It utilizes the simple observation that we can avoid saving intermediate tensors necessary for backward computation if we just recompute them on demand instead.

Currently there are two implementations of activation checkpointing available in PyTorch: reentrant and non-reentrant. The non-reentrant version was implemented later to address some of the limitations of the reentrant checkpoint, which are detailed in PyTorch's official docs. You can pass the use_reentrant flag to specify which version of checkpoint to use. Currently, the use_reentrant flag is optional and the reentrant version is the default. In 2.1, however, not explicitly passing the flag will be deprecated. In a future version of PyTorch, non-reentrant will become the default.

In this post, we first give some background about how PyTorch’s automatic differentiation works in general. Then we explore the new non-reentrant implementation of activation checkpointing and compare it with the earlier reentrant implementation. The implementations presented will be simplified for clarity.

Autograd in PyTorch

Before we dive in, let us briefly review some concepts needed for later.

The basic building block of PyTorch for storing and manipulating data is the tensor. By default, a tensor is not too different from a numpy array with GPU support. When a tensor has its .requires_grad attribute set to True, the autograd engine kicks in.

Every transformation applied to the tensor then creates — along with the resulting tensor — a special object that knows how to compute the transformation’s backward pass for backpropagation. This object can be accessed through the result tensor’s .grad_fn attribute.

The same object is also connected to other similar objects, all serving as nodes in the Directed Acyclic Graph (DAG) called the computational graph. When a new node is created, autograd adds it to the graph by making its .next_functions attribute point to the existing nodes from which it was created.

Let’s focus on a concrete example. In the following code snippet:

import torch

a = torch.tensor([2.], requires_grad=True)
b = torch.tensor([3.], requires_grad=True)
c = a + b
d = c.sin()

c and d have nodes corresponding to the backward pass of the add and sine functions, respectively.

a and b, which are tensors that were created directly and not as the result of an operation, are called leaf tensors.

Instead of a regular backward node, each leaf tensor gets an AccumulateGrad node, whose .variable attribute points back to the tensor itself.

A helpful way to think about this is as two tiers. One tier is made of the tensors, which are not connected to each other but may point to a backward function through their .grad_fn attribute. The other tier is made of the backward functions, which form the computational graph: they are connected to each other through their .next_functions attributes and are unaware of the tensors, except for the special AccumulateGrad nodes.
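Continuing the snippet above, these attributes can be inspected directly (the exact repr strings vary across PyTorch versions):

print(d.grad_fn)                 # <SinBackward0 object at ...>
print(d.grad_fn.next_functions)  # ((<AddBackward0 object at ...>, 0),)

add_node = d.grad_fn.next_functions[0][0]
acc_a, acc_b = (fn for fn, _ in add_node.next_functions)
print(acc_a.variable is a, acc_b.variable is b)  # True True: the AccumulateGrad nodes point back to the leaves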

Each operation performed on a tensor with requires_grad creates — along with the resulting tensor — a new node in the computational graph. This behavior can be disabled inside a torch.no_grad() context manager and re-enabled in an inner torch.enable_grad() context manager.

However, not all tensor methods create nodes in the computational graph. For example, .detach() returns a new tensor that shares the same data but has no .grad_fn, disconnected from any computational graph.
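For instance, continuing the same snippet:

with torch.no_grad():
    e = a + b          # no graph is recorded: e.grad_fn is None and e.requires_grad is False
    with torch.enable_grad():
        f = a + b      # recording is re-enabled: f.grad_fn is an AddBackward0 node

g = c.detach()         # same data as c, but g.grad_fn is None and g.requires_grad is False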

Higher level modules in the torch.nn package, such as Linear and MultiheadAttention, do not themselves take part in the computational graph. Instead, when they are called, they simply add the lower-level nodes they are composed of.

For example, consider a Linear block, which is made of multiple operations including matrix multiply with a weight tensor and addition with its bias tensor:

from torch import nn

x = torch.tensor([2.])
fc = nn.Linear(1, 1)
y = fc(x)

Under the hood, the fc(x) call creates the backward nodes for these lower-level operations (grouping the temporary intermediate tensors for simplicity), with AccumulateGrad nodes attached for fc.weight and fc.bias.
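We can inspect this structure by walking the .next_functions links (the exact node names depend on the PyTorch version, but AccumulateGrad leaves for fc.weight and fc.bias are always reachable):

def walk(fn, depth=0):
    # recursively print the backward graph rooted at fn
    print('  ' * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        if next_fn is not None:
            walk(next_fn, depth + 1)

walk(y.grad_fn)  # somewhere in this tree sit the AccumulateGrad nodes for fc.weight and fc.bias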

Once the computational graph has been constructed, a call to tensor.backward() — in turn calling torch.autograd.backward() — will recursively compute the gradients up to the leaf nodes where .grad is stored. When this process halts, the role of the computational graph ends and it is discarded (unless retain_graph=True is specified).
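Going back to the first example, where d = (a + b).sin():

d.backward()             # runs SinBackward0, then AddBackward0, then the two AccumulateGrad nodes
print(a.grad, b.grad)    # both equal cos(a + b), now stored on the leaf tensors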

The new hook-based non-reentrant variant

The non-reentrant variant of activation checkpointing makes use of autograd’s saved variable hooks mechanism.

Here's a simple example of how hooks are used:

storage = []

def pack(x):
    storage.append(x)
    return len(storage) - 1

def unpack(x):
    return storage[x]

x = torch.randn(1024, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
    y = torch.square(x)
y.sum().backward()

Behind the scenes, the square operator saves its input for the backward calculation (as the derivative of f(x) = x² is 2x). Here, instead of saving the "large" tensor, we only store its (lightweight) index in the graph, and use that index to reconstruct it later. Though in this toy example the actual tensor is still stored in the storage list (so no memory is actually saved), this is the basis for the real non-reentrant version of activation checkpointing.

With this understanding, let us explore the non-reentrant implementation of activation checkpointing based on saved tensor hooks. It has been simplified for clarity, but the overall structure remains the same:

class Frame:  # a struct for shared variables
    def __init__(self):
        self.recomputed = []
        self.count = 0


class RecomputationHook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame):
        def pack(x):
            frame.recomputed.append(x.detach())
            return x.detach()

        def unpack(x):  # only relevant for more complex scenarios
            return x

        super().__init__(pack, unpack)


class CheckpointHook(torch.autograd.graph.saved_tensors_hooks):
    def __init__(self, frame, run_function, args):
        def pack(unused_x):
            i = frame.count
            frame.count += 1
            return i

        def unpack(i):
            if not frame.recomputed:  # only once, while unpacking the first tensor during backward
                with RecomputationHook(frame), torch.autograd.enable_grad():
                    run_function(*args)
            res = frame.recomputed[i]
            frame.recomputed[i] = None
            return res

        super().__init__(pack, unpack)


def checkpoint_without_reentrant(run_function, *args):
    with CheckpointHook(Frame(), run_function, args):
        res = run_function(*args)
    return res

When the non-reentrant activation checkpoint is called, the function’s forward pass is run in a CheckpointHook context manager. Under this context manager, any tensor packed and saved for the backward pass is discarded and replaced with a placeholder (here we arbitrarily use its index i).

During the backward pass, the first backward function that tries to access and unpack its saved tensors triggers run_function to be re-run under a RecomputationHook, which intercepts any tensors being saved and stores them in the recomputed list (detached from the computational graph to avoid reference cycles). It is important to note that the whole mechanism relies on tensors being saved in the same order during the original forward and the recomputation. To make sure that is the case, the real implementation also contains code to save and restore global state (e.g. preserving RNG states, which is important to ensure that modules such as Dropout produce the same output in both calls to run_function).
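For instance, a rough sketch of the RNG bookkeeping (the real code also handles CUDA RNG state and other global state):

cpu_rng_state = torch.get_rng_state()   # captured when the checkpointed forward first runs

# ...later, immediately before re-running run_function during recomputation:
torch.set_rng_state(cpu_rng_state)      # so modules like Dropout draw the same random numbers again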

Much of the code not shown here deals with handling more complex cases — in the actual code, each of the variables has its tensor saved once per graph, and there’s also an early stopping mechanism to minimize unnecessary computations. The overall structure, however, is the same.
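To make the flow concrete, here is how the simplified checkpoint_without_reentrant above could be exercised (the block function and sizes are just illustrative):

def block(x):
    return torch.square(x).sin()

x = torch.randn(1024, requires_grad=True)
y = checkpoint_without_reentrant(block, x)  # forward: saved tensors are replaced by indices
y.sum().backward()                          # backward: block is re-run once to recompute them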

What’s New

Here are some of the scenarios that the new version of non-reentrant activation checkpointing in the 2.1 release will support:

1. Nested checkpointing — calling another checkpointed function from within a checkpointed function. This feature allows the user to make an even more extreme trade-off between memory and compute, potentially reducing the theoretical minimum memory even further, to O(log(n)) (from O(sqrt(n)) in the non-nested case):
def inner1(x):
    ...

def inner2(x):
    ...

def outer(x):
    y = checkpoint(inner1, x)
    z = checkpoint(inner2, y)
    return z

a = torch.ones(1, requires_grad=True)
out = checkpoint(outer, a)
out.backward()

2. Support for calling .grad()/.backward() within checkpointed functions — this is useful for higher-order gradient computation.

3. Improved checks for non-determinism and improved debuggability. Recall that an important assumption the non-reentrant checkpoint makes is that the original and recomputed forward calls must save tensors for the backward pass in the exact same order. Beginning in 2.1, basic tensor metadata will be stored and checked to help validate that this is the case. Furthermore, if any checks for non-determinism fail, users can run checkpoint with debug=True, which can provide traces of the ops executed during the original and recomputed runs, as well as stack traces of where those ops were called, to help pinpoint where the non-determinism occurred (see the snippet after this list).

4. Improved memory savings when retain_graph is specified
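As a small illustration of item 3, a hypothetical call enabling the debug flag might look like this (the block function is made up for the example):

from torch.utils.checkpoint import checkpoint

def block(x):
    return torch.square(x).sin()

x = torch.randn(32, requires_grad=True)
out = checkpoint(block, x, use_reentrant=False, debug=True)
out.sum().backward()  # if the recomputed ops ever diverge from the original run, the error
                      # can include per-run op traces and stack traces to locate the issue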

For more information, please consult the docs. There’s also a comprehensive comment inside the code that details the various scenarios handled, and consulting it might be necessary to get a complete understanding. This design doc also contains some information about why choices were made and what new scenarios are supported.

The reentrant variant

An earlier implementation, called the reentrant variant, does not utilize saved variable hooks but instead uses a custom autograd Function, modifying the computational graph.

As the official implementation contains many details, we focus on a simplified one instead:

class Checkpoint(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, input):
        ctx.run_function = run_function
        ctx.save_for_backward(input)

        with torch.no_grad():
            output = run_function(input)

        return output

    @staticmethod
    def backward(ctx, output_grad):
        run_function = ctx.run_function
        input, = ctx.saved_tensors

        detached_input = input.detach()
        detached_input.requires_grad_(input.requires_grad)

        with torch.enable_grad():
            output = run_function(detached_input)

        torch.autograd.backward(output, output_grad)

        return None, detached_input.grad

In the forward pass, we calculate the output of the given run_function module on the given input. Since the call is inside a with torch.no_grad() context manager, no intermediate nodes and no backward nodes are created — the output of the operation is directly connected through a CheckpointBackward to the backward function of the inputs.

In the backward pass, we first load the objects passed to forward. Then, we detach the input and re-run the forward computation, this time in a torch.enable_grad() context manager, building a computational graph up to the output, which we then backpropagate through to populate the gradients of the parameters that live inside this block. Note how we have to return two values, one for run_function and one for input, but the former is None as a module is not a differentiable tensor.

What is noteworthy is how the gradient calculations are no longer part of the main computational graph: each time backward reaches a Checkpoint node, a "mini computational graph" is constructed in which the actual gradient computations happen, while the original graph just coordinates and forwards gradients.

The inner graph is not completely independent: we do need to pass output_grad into it, and return the input gradient it computes (detached_input.grad) back to the rest of the outer graph. But all the parameters of the current block, which hide behind run_function, get their .grad attributes populated inside that inner backward() call.

To make things simple, we left out some details such as saving/restoring global state for the second forward pass, and allowing a varying number of arguments and outputs, not all of which are tensors.
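As a quick sanity check of the simplified sketch (the Linear block and shapes are just for illustration):

x = torch.randn(16, requires_grad=True)
fc = nn.Linear(16, 16)

y = Checkpoint.apply(fc, x)  # forward runs under no_grad; only x is saved
y.sum().backward()           # backward re-runs fc with grad enabled and backpropagates through it

print(x.grad.shape)          # the gradient returned to the outer graph
print(fc.weight.grad.shape)  # populated by the inner backward() call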

This implementation is called the reentrant variant as it uses a nested backward pass, called “reentrant backward” in PyTorch terminology. While using a nested backward pass might seem simple, in practice this implementation has limitations, e.g. it does not work well with DDP and FSDP in some cases.

Usage

Luckily, the complexities of both designs are wrapped in a simple-to-use API. The implementation to use is selected with the use_reentrant flag, and use_reentrant=False (i.e. the new implementation) will become the default in a future version:

from torch.utils.checkpoint import checkpoint

checkpoint(run_function, args, use_reentrant=False)

You may refer to the docs for additional arguments and options.
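For a fuller picture, here is a minimal sketch of how checkpointing is typically applied per block inside a model (the Block and Model classes and the sizes are illustrative, not from the original post):

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))

    def forward(self, x):
        return self.net(x)


class Model(nn.Module):
    def __init__(self, dim=128, depth=4):
        super().__init__()
        self.blocks = nn.ModuleList(Block(dim) for _ in range(depth))

    def forward(self, x):
        out = x
        for block in self.blocks:
            # recompute each block's activations during backward instead of storing them
            out = checkpoint(block, out, use_reentrant=False)
        return out


model = Model()
loss = model(torch.randn(8, 128)).sum()
loss.backward()  # each block is re-run once here to recompute its activations

Only the inputs and outputs of each block stay alive between the forward and backward passes; everything inside a block is recomputed on demand.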

We hope that after understanding these details, you will be able to use activation checkpointing more effectively, customize it to your needs and contribute to the PyTorch repository.

Acknowledgements

We would like to thank Geeta Chauhan from Meta AI/ML and the PyTorch team for their assistance in preparing this post.
