Bring Your Own Compiler/Optimization in PyTorch

achang
Nov 26, 2022

TL;DR: code example

This tutorial demonstrates some methods and PyTorch features that help you integrate your own optimization passes into PyTorch. You will learn how to improve the performance of DNN models with graph transformations and how to apply those transformations directly within the PyTorch framework.

Summary of methods shown here:

  • Swap layers: change the layers in a model
  • TorchScript passes: create and register an optimization pass
  • TorchScript custom compile: gather model subgraphs and compile them
  • Dynamo: the latest way to integrate different compiler backends into PyTorch

For this tutorial, we will study the BERT model and try to optimize it.

Setup

Create a conda environment with pytorch, transformers, functorch and dynamo. This tutorial was tested on pytorch 1.12, functorch 0.2.1, dynamo 1.12 and transformers 2.24.

Profile

First, let's profile BERT with torch.profiler to see the breakdown of operations. The profiler can also generate an execution trace and export it to a JSON file, which can be visualized with chrome://tracing/
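Here is a minimal sketch of this profiling step. To keep it self-contained, it uses a small stack of nn.TransformerEncoderLayer modules as a stand-in for BERT. This is an assumption: with transformers installed, you would profile BertModel instead. The sizes and the output file name bert_trace.json are arbitrary.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

# Stand-in for BERT (assumption): two small Transformer encoder layers
layer = nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256, batch_first=True)
model = nn.TransformerEncoder(layer, num_layers=2).eval()
x = torch.randn(2, 16, 64)  # (batch, sequence length, hidden size)

# Record CPU operator timings (and input shapes) for one forward pass
with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
    model(x)

# Operator breakdown, most expensive first
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))

# Execution trace that can be loaded in chrome://tracing
prof.export_chrome_trace("bert_trace.json")
```

The same two calls, key_averages() and export_chrome_trace(), work unchanged when the model is a real BERT.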

Figure1: Example of torch.profile output for BERT

The profile shows that BERT is dominated by matrix multiplication operations, and that its main modules are BertSelfAttention and BertLayer.

It also helps to visualize the computation graph. To do that, export BERT to ONNX and view it with Netron. Try exporting submodules of the model, such as BertSelfAttention, to get a simpler graph.

Figure2: Part of BERTSelfAttention with possible optimization ideas highlighted. Graph image cut to fit this post.

This suggests some possible fusions:

  • In blue: fuse the constant add and div into the preceding matmul.
  • In red: combine the three matmuls (query, key, value) into one, since they share the same input.
  • Convert LayerNorm to BatchNorm so that BatchNorm can be fused into the matmul. This may degrade accuracy, but it doesn't seem too bad in some cases (example).
  • Fuse GELU into the matmul, or replace it with an in-place ReLU fused with the matmul. This may degrade accuracy, but retraining could recover it (example).
  • Since most operations are matmuls, implementing an efficient matmul kernel for a particular device, or using a special algorithm (Winograd), could be beneficial.
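The red idea can be sketched as follows. Three nn.Linear projections that share an input are merged into one wider nn.Linear, and its output is split back into query, key and value. The helper name fuse_qkv and the sizes are hypothetical. BertSelfAttention's query, key and value layers are nn.Linear modules, but wiring the fused layer back into the attention forward pass is left out of this sketch.

```python
import torch
import torch.nn as nn

def fuse_qkv(q: nn.Linear, k: nn.Linear, v: nn.Linear) -> nn.Linear:
    """Hypothetical helper: stack three projections that share an input into one Linear."""
    fused = nn.Linear(q.in_features, q.out_features + k.out_features + v.out_features)
    with torch.no_grad():
        # nn.Linear stores weight as (out_features, in_features), so stack along dim 0
        fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))
        fused.bias.copy_(torch.cat([q.bias, k.bias, v.bias], dim=0))
    return fused

hidden = 64
query, key, value = (nn.Linear(hidden, hidden) for _ in range(3))
qkv = fuse_qkv(query, key, value)

x = torch.randn(2, 16, hidden)  # (batch, sequence length, hidden size)
# One matmul instead of three, then split the result into its three parts
q_out, k_out, v_out = qkv(x).split(hidden, dim=-1)
```

The fused layer computes the same values. It replaces three smaller matmuls with one larger one, which usually uses the hardware better.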

You can also simply use torch.jit.optimize_for_inference, which runs a set of built-in optimization passes on a scripted model.
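Here is a short sketch, assuming a small stand-in model. Any scripted nn.Module in eval mode should be handled the same way, and optimize_for_inference freezes the module first if it is not already frozen.

```python
import torch
import torch.nn as nn

# Small stand-in model (assumption); optimize_for_inference expects eval mode
model = nn.Sequential(nn.Linear(64, 128), nn.GELU(), nn.Linear(128, 64)).eval()

# Script, then freeze and run the built-in inference passes
scripted = torch.jit.script(model)
optimized = torch.jit.optimize_for_inference(scripted)

x = torch.randn(8, 64)
with torch.no_grad():
    reference = model(x)
    result = optimized(x)

print(optimized.graph)  # inspect what the passes changed
```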

In the rest of this tutorial, we implement some of these ideas using different methods and integrate them with PyTorch.

Swap modules

This is the simplest way to transform a model: swap the layers, then retrain if needed. In this example, we replace GELU modules with in-place ReLU using PyTorch's named_children and setattr.
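Here is a minimal sketch of the recursive swap. The helper name swap_modules is hypothetical, and the toy model uses nn.GELU. In a given BERT implementation, the activation may be a library-specific class rather than nn.GELU, so pass in whichever type you actually want to replace.

```python
import torch.nn as nn

def swap_modules(module: nn.Module, old_type: type, make_new) -> nn.Module:
    """Hypothetical helper: recursively replace every child of old_type with make_new()."""
    for name, child in module.named_children():
        if isinstance(child, old_type):
            setattr(module, name, make_new())  # replace the child in its parent
        else:
            swap_modules(child, old_type, make_new)  # descend into containers
    return module

model = nn.Sequential(
    nn.Linear(16, 32),
    nn.GELU(),
    nn.Sequential(nn.Linear(32, 32), nn.GELU()),
)
swap_modules(model, nn.GELU, lambda: nn.ReLU(inplace=True))
print(model)
```

A factory (make_new) is passed instead of a single instance so that every replaced position gets its own module object.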

TorchScript: graph rewrite passes

Now let's study how to create functions that fuse layers or perform graph optimizations when the model runs.

TorchScript generates a computation graph from a model, and this graph is a form of intermediate representation (IR). It is worth reading about the TorchScript IR and its code. A few important data structures worth looking at, such as Value and Graph, are defined here, and there are several example passes here.
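To get a feel for the IR, this small sketch scripts a scaled matmul (roughly BERT's attention-score computation, with a scale of 8.0 assumed) and walks its Graph. Each Node has a kind() such as aten::matmul, and its inputs() and outputs() are Value objects.

```python
import torch

def attention_scores(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Scaled dot product, as in BERT's self-attention (scale 8.0 = sqrt(64) assumed)
    return torch.matmul(q, k.transpose(-1, -2)) / 8.0

scripted = torch.jit.script(attention_scores)
graph = scripted.graph
print(graph)  # textual TorchScript IR

# Walk the nodes; each input is a Value with a debug name and a type
for node in graph.nodes():
    inputs = [v.debugName() for v in node.inputs()]
    print(node.kind(), inputs)

kinds = [node.kind() for node in graph.nodes()]
```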

Let's take a look at fuse_relu.cpp, which combines add with relu into a single add_relu operator. Note that aten contains the list of PyTorch operators; for example, add_relu is defined here.

For our BERT case, we want to create a pass that fuses add and div, which we can do like here. We use the graph rewriter to match a pattern and swap it for a fused operation; for example, an add combined with a constant div becomes aten::addcdiv. A pass usually takes in a Graph and modifies it in place.
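You can try the same match-and-replace mechanism from Python. This sketch relies on torch._C._jit_pass_custom_pattern_based_rewrite_graph, an internal binding around TorchScript's SubgraphRewriter. Because it is a private API, it may change between versions. The sketch rewrites a + b / c into aten::addcdiv(a, b, c), using the fact that addcdiv computes input + value * tensor1 / tensor2. A C++ pass would register the same two IR patterns with SubgraphRewriter directly. The function scores_plus_mask is a hypothetical example.

```python
import torch

def scores_plus_mask(mask: torch.Tensor, scores: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # BERT-like step: add the attention mask to scaled scores
    return mask + scores / scale

scripted = torch.jit.script(scores_plus_mask)
graph = scripted.graph

# Pattern to find: %r = %a + %alpha * (%b / %c)
pattern = """
graph(%a, %b, %c, %alpha):
    %d = aten::div(%b, %c)
    %r = aten::add(%a, %d, %alpha)
    return (%r)
"""
# Replacement: addcdiv(a, b, c, value=alpha) computes the same result in one op
replacement = """
graph(%a, %b, %c, %alpha):
    %r = aten::addcdiv(%a, %b, %c, %alpha)
    return (%r)
"""

# Rewrites the graph in place (internal API, may change between versions)
torch._C._jit_pass_custom_pattern_based_rewrite_graph(pattern, replacement, graph)
print(graph)

kinds = [node.kind() for node in graph.nodes()]
```

The divisor here is a tensor on purpose: addcdiv takes tensor arguments, so a pattern over a scalar divisor would need a different replacement.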

Next, we want this pass to be part of TorchScript so that the optimization is applied whenever we run the PyTorch model. We can register custom passes like here, and we can also compose a sequence of passes.

We use pybind to wrap the C++ functions into a Python module, like here for example. The setup.py uses torch.utils.cpp_extension so that your C++ functions are built together with the PyTorch libraries.

Build and install your new optimization module with: python setup.py install

This creates a Python module that calls the C++ functions. For example, we expose enable and disable functions, which you can see in action here.

TorchScript: adding custom subgraph compilation

Now we know how to fuse operations. But what if we want to implement that fused operator differently, or the aten::addcdiv kernel didn't exist? In that case, we can fuse the operations into a custom subgraph operation that we compile and implement ourselves.

Take a look here, where we register a new custom operation and implement it. The file fusion_pass.cpp defines a pass that creates a subgraph operator, opt::CompilationGroup, whenever it sees operators listed in canHandle. This special operator is executed by our own code rather than by PyTorch, while any operation not in canHandle is handled by the default PyTorch backends. In our example, canHandle accepts matmul with constant division, so opt::CompilationGroup is composed of those two operations.

Figure3: Example of a subgraph split that runs our custom implementation of matmul + div. The remaining operations run on PyTorch as usual.

In compiler.cpp, the subgraph is compiled, which lets us generate a custom implementation for a group of operations. Several things happen in run:

  • Stack: holds the input tensors of the subgraph.
  • Subgraph: contains the fused operations of this subgraph, e.g., matmul + div.
  • JIT compilation: on the first run, the subgraph is compiled and saved into a cache, so subsequent runs don't need to compile again.
  • Run implementation: runs the operation on the inputs and generates the output tensor.
  • Return results: output tensors are pushed onto the Stack, so the next layers receive that Stack of tensors as input.

So when you call the model's forward in Python, it calls this run function. The first time, it compiles the subgraph and saves the result as Compiled_info. Afterwards, it uses that cached compilation and goes straight to running the operation.

In our example, we fold the constant div into the matmul weights and call aten::matmul. At that point, you could plug in any matmul implementation instead: OpenBLAS, Eigen, etc.
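To make the run flow concrete, here is a rough Python analogue, not the actual compiler.cpp code. The hypothetical MatmulDivGroup class works in four steps:

  • It pops its inputs off a stack.
  • On the first call, it "compiles" by folding the constant divisor into the weights, since (x @ w) / c equals x @ (w / c).
  • It caches the result by input shape and dtype.
  • It pushes the output back onto the stack.

It assumes the weights and divisor are constants of the subgraph, as with frozen weights.

```python
import torch

class MatmulDivGroup:
    """Hypothetical Python analogue of a compiled matmul + constant-div subgraph."""

    def __init__(self):
        self.cache = {}          # compiled artifacts, keyed by input signature
        self.compile_count = 0   # how many times we actually "compiled"

    def _compile(self, w: torch.Tensor, c: float) -> torch.Tensor:
        self.compile_count += 1
        # Fold the constant division into the weights once: (x @ w) / c == x @ (w / c)
        return w / c

    def run(self, stack: list) -> None:
        # Inputs sit on top of the stack: x, w, c (w and c assumed constant)
        x, w, c = stack[-3:]
        del stack[-3:]
        key = (tuple(x.shape), x.dtype)
        if key not in self.cache:  # first run for this signature: compile and cache
            self.cache[key] = self._compile(w, c)
        # Any matmul implementation could go here (OpenBLAS, Eigen, an accelerator call)
        out = torch.matmul(x, self.cache[key])
        stack.append(out)  # outputs go back on the stack for the next node

group = MatmulDivGroup()
w = torch.randn(64, 32)
stack = [torch.randn(4, 64), w, 8.0]
group.run(stack)
first_out = stack.pop()
```

Because the cache key ignores the weights, this design is only valid when they really are constants. Otherwise the key would need to include them.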

You can also integrate a different hardware backend for a particular set of operations. For example, suppose you built a special matmul-with-division engine on an FPGA, with function calls such as matmul_div_fpga to kick it off. You can use TorchScript to make it part of PyTorch. Every PyTorch model that contains a matmul with div will then be accelerated by your FPGA design.

But there are other ways to achieve this in PyTorch. Let's see…

Dynamo

More recently, dynamo has made it easier to integrate different backends into PyTorch. It uses another computation graph that PyTorch produces: torch.fx.
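For a quick look at a torch.fx graph, here is a sketch that traces a hypothetical TinyAttention module, a simplified stand-in for BERT's self-attention, with torch.fx.symbolic_trace.

```python
import torch
import torch.fx as fx
import torch.nn as nn

class TinyAttention(nn.Module):
    """Hypothetical, simplified self-attention scores (no heads, no value projection)."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.query = nn.Linear(hidden, hidden)
        self.key = nn.Linear(hidden, hidden)
        self.scale = hidden ** 0.5

    def forward(self, x):
        scores = torch.matmul(self.query(x), self.key(x).transpose(-1, -2))
        return torch.softmax(scores / self.scale, dim=-1)

model = TinyAttention()
gm = fx.symbolic_trace(model)  # GraphModule: graph + generated Python code
print(gm.graph)
print(gm.code)

# Each node has an opcode (placeholder, call_module, call_function, ...) and a target
targets = [node.target for node in gm.graph.nodes if node.op == "call_function"]
print([getattr(t, "__name__", str(t)) for t in targets])
```

Unlike TorchScript IR, an FX graph is plain Python data. Passes can therefore be written as ordinary Python loops over gm.graph.nodes.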

We follow a similar approach to the TorchScript compilation above. The only difference is that our compile function receives a torch.fx graph, which is converted into a Graph. We use split_module to partition the torch.fx graph into subgraphs. These subgraphs are fed into a custom compiler, which generates an executable. In our example, the executable just calls PyTorch anyway, but you can add your own operation implementations. Then all you need to do is run your model with dynamo.
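Here is a sketch of such a backend, assuming PyTorch 2.x, where dynamo is built in and is reached through torch.compile(model, backend=...). With the standalone torchdynamo package used in this tutorial, the entry point was torchdynamo.optimize(backend) instead. The hypothetical my_backend receives the FX graph from dynamo. It starts a new partition whenever nodes switch between ops our hypothetical compiler supports (matmul and division) and everything else, then calls split_module. Like the original example, it still runs every submodule with PyTorch; a real backend would lower the supported submodules to its own kernels.

```python
import operator
import torch
import torch.nn as nn
from torch.fx.passes.split_module import split_module

# Ops our hypothetical compiler can handle (plays the role of canHandle)
SUPPORTED = {torch.matmul, operator.truediv}
handled = []  # names of submodules the hypothetical compiler claimed

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Assign every compute node to a partition, opening a new partition
    # whenever we switch between supported and unsupported ops
    part_of, part_supported = {}, {}
    partition, prev = 0, None
    for node in gm.graph.nodes:
        if node.op in ("placeholder", "get_attr", "output"):
            continue
        supported = node.op == "call_function" and node.target in SUPPORTED
        if prev is not None and supported != prev:
            partition += 1
        prev = supported
        part_of[node] = partition
        part_supported[partition] = supported

    # Each partition becomes a submodule named submod_<partition>
    split_gm = split_module(gm, gm, lambda node: part_of.get(node, 0))

    for name, _submod in split_gm.named_children():
        if name.startswith("submod_") and part_supported.get(int(name[len("submod_"):])):
            # A real backend would compile _submod here; this sketch keeps running it in PyTorch
            handled.append(name)
    return split_gm.forward  # dynamo calls this in place of the original graph


class TinyAttention(nn.Module):
    """Hypothetical, simplified self-attention scores (no heads, no value projection)."""

    def __init__(self, hidden: int = 64):
        super().__init__()
        self.query = nn.Linear(hidden, hidden)
        self.key = nn.Linear(hidden, hidden)
        self.scale = hidden ** 0.5

    def forward(self, x):
        scores = torch.matmul(self.query(x), self.key(x).transpose(-1, -2))
        return torch.softmax(scores / self.scale, dim=-1)


model = TinyAttention().eval()
compiled = torch.compile(model, backend=my_backend)
x = torch.randn(2, 8, 64)
out = compiled(x)
```

Increasing the partition number monotonically keeps the partitions in graph order, which avoids cyclic dependencies between the resulting submodules.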

Others

It is worth noting that there are various other ways to optimize model runtime using compilers and graph transformations.

For example, TVM is a compiler for machine learning that targets several devices and integrates with various frameworks. Its BYOC (Bring Your Own Codegen) framework allows new hardware and compilers to be incorporated into TVM. Covering TVM will have to wait for another post.

Another option is functorch, where you can use your own compiler function to compile and run torch.fx graphs. What's more, it also creates computation graphs for training, which means you can write custom compiler optimizations that accelerate training too. Here is a good read about it.
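For example, functorch's AOT Autograd entry point aot_function accepts separate forward and backward compilers, and each receives an FX GraphModule. This sketch assumes the functorch.compile module that also ships with PyTorch 2.x. It uses a trivial compiler that just records, prints and returns the graph it is given. The function fn is a hypothetical example.

```python
import torch
from functorch.compile import aot_function

compiled_graphs = []  # FX graphs our compiler was asked to compile

def my_compiler(fx_module: torch.fx.GraphModule, example_inputs):
    # A real compiler would lower fx_module here; this one only records and prints it
    compiled_graphs.append(fx_module)
    print(fx_module.code)
    return fx_module  # the returned callable runs in place of the graph

def fn(x, w):
    return torch.relu(torch.matmul(x, w) / 8.0).sum()

# Separate compilers for the forward graph and the ahead-of-time traced backward graph
aot_fn = aot_function(fn, fw_compiler=my_compiler, bw_compiler=my_compiler)

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 3, requires_grad=True)
aot_fn(x, w).backward()  # by now both forward and backward graphs have been compiled
```

The backward graph is what makes this useful for training: a custom compiler sees the gradient computation as an ordinary FX graph too.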

References

This tutorial is based on a collection of information from:

[PyTorch JIT compiler tutorial]

[torch tvm example]

[functorch]

[dynamo]

Notes: PyTorch is one of the top ML frameworks and contains tons of features. This post doesn't try to cover everything that is going on in PyTorch; it just shows some methods for doing these things. Check the forums for more. Have fun.
