torch.compile, explained

Kaichao You · Published in PyTorch · Oct 26, 2023

Have you ever felt overwhelmed by the complexities of torch.compile? Diving into its inner workings can feel like black magic, full of bytecode and Python internals that most users never touch, which makes torch.compile hard to understand and debug.

I am excited to introduce depyf, a new tool that pulls out all the artifacts of torch.compile and decompiles the bytecode back into source code, so that every user can read and understand it.

Note: please install the package with pip install depyf before trying the example below.

Example usage

import torch

+ @torch.compile(backend="eager")
- @torch.compile
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

for _ in range(100):
    toy_example(torch.randn(10), torch.randn(10))

+ from depyf.explain import dump_src
+ src = dump_src(toy_example)
+ with open("explained_code.py", "w") as f:
+     f.write(src)

It’s this simple: switch the backend to “eager”, and run the dump_src function to pull out all the artifacts from torch.compile.
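
Why does the backend matter? A torch.compile backend is just a callable that receives each captured FX graph, and the built-in "eager" backend simply runs that graph with ordinary PyTorch, so no real compilation gets in the way of inspection. Below is a minimal sketch of that idea (inspect_backend is our own illustrative name, not part of depyf or PyTorch); because of the graph break in toy_example, it is invoked once per captured subgraph:

import torch

# A backend is a callable taking the captured graph and example inputs.
# Returning gm.forward runs the graph eagerly, like backend="eager",
# while letting us peek at what Dynamo captured.
def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    gm.graph.print_tabular()  # print the captured FX graph
    return gm.forward

@torch.compile(backend=inspect_backend)
def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

toy_example(torch.randn(10), torch.randn(10))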

Note: (+) marks lines to add to your code, (-) marks lines to remove.

In the dumped explained_code.py file, you will see something like the following:

def guard_2(L):
    return (___guarded_code.valid) \
        and (___check_global_state()) \
        and (hasattr(L['b'], '_dynamo_dynamic_indices') == False) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___skip_backend_check() or ___current_backend() == ___lookup_backend(5096739488)) \
        and (___check_tensors(L['b'], L['x'], tensor_check_names=tensor_check_names))

def __compiled_fn_4(L_b_ : torch.Tensor, L_x_ : torch.Tensor):
    l_b_ = L_b_
    l_x_ = L_x_
    mul = l_x_ * l_b_; l_x_ = l_b_ = None
    return (mul,)


def compiled_code_2(b, x):
    return __compiled_fn_4(b, x)[0]


def __resume_at_38_2(b, x):
    # Note: if there is a compiled version below, this function might well not be executed directly. Please check the compiled version if possible.
    return x * b

def compiled___resume_at_38_2(b, x):
    L = {"b": b, "x": x}
    if guard_2(L):
        return compiled_code_2(b, x)
    # Note: this function might well not be executed directly. It might well be compiled again, i.e. adding one more guard and compiled code.
    return __resume_at_38_2(b, x)

#============ end of __resume_at_38_2 ============#

def guard_1(L):
    return (___guarded_code.valid) \
        and (___check_global_state()) \
        and (hasattr(L['b'], '_dynamo_dynamic_indices') == False) \
        and (hasattr(L['x'], '_dynamo_dynamic_indices') == False) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___skip_backend_check() or ___current_backend() == ___lookup_backend(5096739488)) \
        and (___check_tensors(L['b'], L['x'], tensor_check_names=tensor_check_names))

def __compiled_fn_3(L_b_ : torch.Tensor, L_x_ : torch.Tensor):
    l_b_ = L_b_
    l_x_ = L_x_
    b = l_b_ * -1; l_b_ = None
    mul_1 = l_x_ * b; l_x_ = b = None
    return (mul_1,)


def compiled_code_1(b, x):
    return __compiled_fn_3(b, x)[0]


def __resume_at_30_1(b, x):
    # Note: if there is a compiled version below, this function might well not be executed directly. Please check the compiled version if possible.
    b = b * -1
    return x * b

def compiled___resume_at_30_1(b, x):
    L = {"b": b, "x": x}
    if guard_1(L):
        return compiled_code_1(b, x)
    # Note: this function might well not be executed directly. It might well be compiled again, i.e. adding one more guard and compiled code.
    return __resume_at_30_1(b, x)

#============ end of __resume_at_30_1 ============#

def guard_0(L):
    return (___guarded_code.valid) \
        and (___check_global_state()) \
        and (hasattr(L['a'], '_dynamo_dynamic_indices') == False) \
        and (hasattr(L['b'], '_dynamo_dynamic_indices') == False) \
        and (utils_device.CURRENT_DEVICE == None) \
        and (___skip_backend_check() or ___current_backend() == ___lookup_backend(5096739488)) \
        and (___check_tensors(L['a'], L['b'], tensor_check_names=tensor_check_names))

def __compiled_fn_0(L_a_ : torch.Tensor, L_b_ : torch.Tensor):
    l_a_ = L_a_
    l_b_ = L_b_
    abs_1 = torch.abs(l_a_)
    add = abs_1 + 1; abs_1 = None
    x = l_a_ / add; l_a_ = add = None
    sum_1 = l_b_.sum(); l_b_ = None
    lt = sum_1 < 0; sum_1 = None
    return (x, lt)


def compiled_code_0(a, b):
    __temp_29 = __compiled_fn_0(a, b)
    x = __temp_29[0]
    if __temp_29[1]:
        return __resume_at_30_1(b, x)
    return __resume_at_38_2(b, x)


def toy_example(a, b):
    # Note: if there is a compiled version below, this function might well not be executed directly. Please check the compiled version if possible.
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

def compiled_toy_example(a, b):
    L = {"a": a, "b": b}
    if guard_0(L):
        return compiled_code_0(a, b)
    # Note: this function might well not be executed directly. It might well be compiled again, i.e. adding one more guard and compiled code.
    return toy_example(a, b)

#============ end of toy_example ============#

You can explore the code in your favorite IDE. Start from the toy_example function, pay attention to the compiled_toy_example function below it, and walk through all the details of the guards, compiled code, compiled subgraphs, and resume functions. It’s all in readable source-code format!
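
One detail worth noticing: the resume functions exist because the data-dependent branch if b.sum() < 0 cannot be traced into a single graph, so Dynamo inserts a graph break and splits toy_example into three subgraphs. If you just want the break summary without reading the dump, recent PyTorch releases (2.1+) provide torch._dynamo.explain; a minimal sketch, assuming the 2.1-style calling convention:

import torch
import torch._dynamo as dynamo

def toy_example(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

# Reports graph count, graph breaks, and break reasons; here the
# data-dependent `if b.sum() < 0` is the culprit.
explanation = dynamo.explain(toy_example)(torch.randn(10), torch.randn(10))
print(explanation)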

Hopefully, by using this package, everyone can understand torch.compile now! The mental model, visible directly in the dumped code above: each compiled function keeps cached entries of guards plus compiled code; guards decide whether a cached entry can be reused, and when none pass, torch.compile falls back and compiles a new entry.
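
To watch that guard/cache mechanism live, here is a small self-contained sketch (counting_backend is our own illustrative name, not a depyf or PyTorch API): a custom backend counts how many times Dynamo compiles a new entry, and a dtype change on the third call fails the cached tensor guards, forcing a recompile:

import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Called each time Dynamo compiles a new guarded entry.
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph eagerly

@torch.compile(backend=counting_backend)
def double(x):
    return x * 2

double(torch.randn(10))                       # first call: compiles
double(torch.randn(10))                       # guards pass: cache hit
double(torch.randn(10, dtype=torch.float64))  # dtype guard fails: recompiles
print(compile_count)  # 2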

For more advanced usage, please refer to the depyf GitHub repository.
