A Taste of PyTorch C++ frontend API

Venkata Chintapalli
Published in PyTorch
May 2, 2020

Introduction

One major enhancement in the recently released PyTorch 1.5 is that the C++ frontend API is now stable and at parity with Python¹. The C++ frontend is a good fit for low-latency systems, highly multi-threaded environments, and existing C++ codebases; its motivation and use cases are described here³. I wanted to get a taste of the PyTorch C++ frontend API by building a small example, so I took a simple two-layer neural network from Learning PyTorch with Examples². The rest of this post walks through converting that two-layer network from the Python frontend API to the C++ frontend API. The complete code, along with steps to run it, is in the GitHub repo⁴.

Update (05/04/2020): added a plot of the performance measurements.

PyTorch has three main components: tensors, automatic differentiation (autograd), and the nn (neural network) module.

Tensors

Python is a dynamic language, so variables don't need declared data types. C++, by contrast, is statically typed; its type safety lets the compiler produce optimized code, which helps in building fast, demanding applications. So the first step in converting the Python code is to define the data types. Modern C++ also provides the auto type specifier, which lets the compiler infer a variable's type at compile time.
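To illustrate, here is a small standalone C++17 sketch (standard library only, no libtorch; the names and values are illustrative, not taken from the original example). Explicitly typed variables and auto variables alike get a fixed type at compile time, which the static_assert checks confirm before the program ever runs.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>
#include <vector>

// Explicit types: every variable's type is spelled out and checked by the compiler.
const std::int64_t batch_size = 64;  // illustrative value
const double learning_rate = 1e-4;

// auto: the compiler deduces an equally static type from the initializer.
auto hidden = std::int64_t{100};                             // std::int64_t
auto step_size = 1e-4;                                       // unsuffixed literal -> double
auto shape = std::vector<std::int64_t>{batch_size, hidden};  // std::vector<std::int64_t>
auto scaled = batch_size * learning_rate;                    // int64_t * double -> double

// These checks run at compile time; a wrong type would fail the build.
static_assert(std::is_same_v<decltype(hidden), std::int64_t>);
static_assert(std::is_same_v<decltype(step_size), double>);
static_assert(std::is_same_v<decltype(shape), std::vector<std::int64_t>>);
static_assert(std::is_same_v<decltype(scaled), double>);
```

In other words, auto saves typing but does not make C++ dynamic: once deduced, the type never changes.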

Let’s start by creating random input and output tensors (N is the batch size, D_in and D_out are the input and output dimensions) for training the neural network to map x to y, i.e., to learn the function f such that y = f(x).

# Python
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

In C++, the torch namespace is accessed with the scope resolution operator :: instead of dot notation, and the tensor shape is passed as a brace-enclosed list.

// C++
torch::Tensor x = torch::randn({N, D_in});
torch::Tensor y = torch::randn({N, D_out});
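When porting, keep in mind that torch.randn and torch::randn draw from the standard normal distribution, while torch.rand and torch::rand draw uniformly from [0, 1), so the two are not interchangeable. The plain C++ sketch below mimics the two distributions with <random> on a flat row-major buffer; Matrix, sample_normal, sample_uniform, min_value and mean are hypothetical helpers for illustration, not libtorch's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Row-major rows x cols matrix in a flat vector, standing in for a 2-D tensor.
struct Matrix {
    std::int64_t rows;
    std::int64_t cols;
    std::vector<double> data;
};

// Fills a rows x cols matrix by drawing every element from `dist`.
template <typename Dist>
Matrix sample(std::int64_t rows, std::int64_t cols, Dist dist, std::mt19937& gen) {
    Matrix m{rows, cols, std::vector<double>(static_cast<std::size_t>(rows * cols))};
    for (double& v : m.data) v = dist(gen);
    return m;
}

// The distribution torch::randn samples from: standard normal N(0, 1).
Matrix sample_normal(std::int64_t rows, std::int64_t cols, std::mt19937& gen) {
    return sample(rows, cols, std::normal_distribution<double>(0.0, 1.0), gen);
}

// The distribution torch::rand samples from: uniform on [0, 1).
Matrix sample_uniform(std::int64_t rows, std::int64_t cols, std::mt19937& gen) {
    return sample(rows, cols, std::uniform_real_distribution<double>(0.0, 1.0), gen);
}

// Smallest element: uniform samples are never negative, normal ones often are.
double min_value(const Matrix& m) {
    double lo = m.data.front();
    for (double v : m.data) lo = v < lo ? v : lo;
    return lo;
}

// Sample mean: close to 0 for N(0, 1), close to 0.5 for uniform [0, 1).
double mean(const Matrix& m) {
    double sum = 0.0;
    for (double v : m.data) sum += v;
    return sum / static_cast<double>(m.data.size());
}
```

With a few thousand samples, the mean of the normal matrix sits near 0 and the uniform one near 0.5, which makes an accidental rand/randn swap easy to spot.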

Two Layer Neural Network Model

In Python, neural networks derive from the reusable base class torch.nn.Module. Submodules are automatically detected and registered when they are assigned as attributes of a module. Here is a simple two-layer network using the Python interface:

import torch

# Python neural network model defined as a class
class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred
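For reference, the forward pass computes h = max(0, W1 x + b1) and then y_pred = W2 h + b2, where each nn.Linear holds a weight matrix of shape (out_features, in_features) plus a bias. The following plain C++ sketch spells that arithmetic out for a single input vector. LinearSketch, relu and two_layer_forward are hypothetical names; real libtorch code would use torch::nn::Linear on batched tensors instead.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal linear layer for one input vector: y = W x + b.
// W is stored row-major as out_features x in_features, like nn.Linear's weight.
struct LinearSketch {
    std::size_t in_features;
    std::size_t out_features;
    std::vector<double> weight;  // out_features * in_features entries
    std::vector<double> bias;    // out_features entries

    std::vector<double> forward(const std::vector<double>& x) const {
        assert(x.size() == in_features);
        std::vector<double> y(bias);  // start from the bias
        for (std::size_t i = 0; i < out_features; ++i)
            for (std::size_t j = 0; j < in_features; ++j)
                y[i] += weight[i * in_features + j] * x[j];
        return y;
    }
};

// Element-wise clamp(min=0), i.e. ReLU.
std::vector<double> relu(std::vector<double> v) {
    for (double& e : v) e = std::max(0.0, e);
    return v;
}

// Same structure as TwoLayerNet.forward: linear1 -> ReLU -> linear2.
std::vector<double> two_layer_forward(const LinearSketch& l1, const LinearSketch& l2,
                                      const std::vector<double>& x) {
    return l2.forward(relu(l1.forward(x)));
}
```

For example, with l1 weights {1, -1, -1, 1}, zero bias, and input {3, 1}, the hidden pre-activation is {2, -2}, which ReLU turns into {2, 0}.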

In C++, neural network modules are implemented as structs that derive from torch::nn::Module. A module can be defined with either value semantics or reference semantics. First, here is a module defined with value semantics.

#include <torch/torch.h>

// C++ neural network model defined with value semantics
struct TwoLayerNet : torch::nn::Module {
  // constructor with submodules registered in the initializer list
  TwoLayerNet(int64_t D_in, int64_t D_out, int64_t H) :
      linear1(register_module("linear1", torch::nn::Linear(D_in, H))),
      linear2(register_module("linear2", torch::nn::Linear(H, D_out)))
  {}

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(linear1->forward(x));
    x = linear2->forward(x);
    return x;
  }

  torch::nn::Linear linear1;
  torch::nn::Linear linear2;
};

// Usage: access the object using the dot . operator
TwoLayerNet model(D_in, D_out, H);
model.to(device);  // e.g. torch::Device device(torch::kCPU);
torch::Tensor y_pred = model.forward(x);
auto params = model.parameters();
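The difference between the two styles shows up when a module is copied or passed around. This generic C++ sketch (TinyModule, train_by_value and train_by_reference are hypothetical, and it does not model libtorch's own copy behavior) shows the underlying language rule: a value is copied together with what it owns, while copies of a std::shared_ptr all point to one shared object, which is what reference semantics build on.

```cpp
#include <cassert>
#include <memory>
#include <vector>

// Hypothetical stand-in for a module: it just owns a parameter vector.
struct TinyModule {
    std::vector<double> weight{1.0, 2.0};
    void scale(double s) {
        for (double& w : weight) w *= s;
    }
};

// Value semantics: the function gets its own copy, so the caller's
// module is unchanged after "training".
void train_by_value(TinyModule model) { model.scale(10.0); }

// Reference semantics: only the pointer is copied, so the update is
// visible through every shared_ptr to the same module.
void train_by_reference(std::shared_ptr<TinyModule> model) { model->scale(10.0); }
```

Passing by value silently trains a throwaway copy, which is one reason sharing a module through a pointer is the more convenient default.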

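Unlike Python, C++ does not detect submodules automatically: each one must be registered by calling register_module() in the module's constructor (here, in its initializer list). Registration is what lets methods such as parameters() recursively walk the module tree. Its declaration in the PyTorch source reads: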

/// Registers a submodule with this `Module`.
///
/// This method deals with `ModuleHolder`s.
///
/// Registering a module makes it available to methods such as
/// `modules()`, `clone()` or `to()`.
///
/// \rst
/// .. code-block:: cpp
///
///   MyModule::MyModule() {
///     submodule_ = register_module("linear", torch::nn::Linear(3, 4));
///   }
/// \endrst
template <typename ModuleType>
std::shared_ptr<ModuleType> register_module(
    std::string name,
    ModuleHolder<ModuleType> module_holder);
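To see why registration matters, here is a deliberately simplified, hypothetical sketch in plain C++ (ModuleSketch, LinearLike and Net are made-up names, not libtorch classes). A module can only report the parameters it knows about, so parameters() collects its own entries plus those of every registered child, recursively, and a submodule that was merely stored in a member is skipped.

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical, much-simplified module base (not libtorch's torch::nn::Module).
struct ModuleSketch {
    virtual ~ModuleSketch() = default;

    // Own parameters first, then every registered child's, recursively.
    std::vector<double*> parameters() {
        std::vector<double*> out;
        for (auto& entry : params_) out.push_back(&entry.second);
        for (auto& entry : children_) {
            std::vector<double*> sub = entry.second->parameters();
            out.insert(out.end(), sub.begin(), sub.end());
        }
        return out;
    }

protected:
    void register_parameter(const std::string& name, double init) { params_[name] = init; }

    // Records the child under `name` and returns the same pointer, so the call
    // can initialize a member in a constructor's initializer list.
    template <typename M>
    std::shared_ptr<M> register_module(const std::string& name, std::shared_ptr<M> child) {
        children_[name] = child;
        return child;
    }

private:
    std::map<std::string, double> params_;
    std::map<std::string, std::shared_ptr<ModuleSketch>> children_;
};

// A leaf "layer" with two scalar parameters.
struct LinearLike : ModuleSketch {
    LinearLike() {
        register_parameter("weight", 1.0);
        register_parameter("bias", 0.0);
    }
};

// One child is registered; the other is only stored in a member.
struct Net : ModuleSketch {
    Net()
        : registered(register_module("linear1", std::make_shared<LinearLike>())),
          forgotten(std::make_shared<LinearLike>()) {}

    std::shared_ptr<LinearLike> registered;
    std::shared_ptr<LinearLike> forgotten;  // invisible to parameters()
};
```

Here Net::parameters() returns only the two parameters of the registered child, which is the same pitfall an unregistered submodule causes in a real C++ model: its weights would never reach the optimizer.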

The recommended way to define modules with the C++ frontend API is reference semantics, using the std::shared_ptr-based module holder API. You write the implementation as TwoLayerNetImpl, and the TORCH_MODULE(TwoLayerNet) macro generates a TwoLayerNet class that is effectively a wrapper around std::shared_ptr<TwoLayerNetImpl>. This simplifies the code: instead of writing std::make_shared<TwoLayerNetImpl>(1000, 10, 64), you can write TwoLayerNet(1000, 10, 64). Copying a TwoLayerNet copies only the pointer, so all copies refer to the same underlying module.

// C++ neural network model defined with reference semantics
struct TwoLayerNetImpl : torch::nn::Module {
  // module holders are constructed in the initializer list, then registered
  TwoLayerNetImpl(int64_t D_in, int64_t D_out, int64_t H) :
      linear1(D_in, H), linear2(H, D_out) {
    register_module("linear1", linear1);
    register_module("linear2", linear2);
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(linear1->forward(x));
    x = linear2->forward(x);
    return x;
  }

  torch::nn::Linear linear1{nullptr};  // empty holder by default; set by the constructor
  torch::nn::Linear linear2{nullptr};  // empty holder by default; set by the constructor
};
TORCH_MODULE(TwoLayerNet);

// Usage: access the object using the arrow -> operator
TwoLayerNet model(D_in, D_out, H);
model->to(device);
torch::Tensor y_pred = model->forward(x);
auto params = model->parameters();
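Conceptually, the holder set up by TORCH_MODULE behaves like the much-simplified template below. HolderSketch, TwoLayerNetSketchImpl and TwoLayerNetSketch are hypothetical names, and the real torch::nn::ModuleHolder does considerably more. The sketch only shows the three properties discussed above: constructor arguments are forwarded to the Impl, operator-> reaches the Impl, and copying the holder copies just the std::shared_ptr.

```cpp
#include <cassert>
#include <memory>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for torch::nn::ModuleHolder: it owns a
// std::shared_ptr<Impl>, so every copy of the holder shares one Impl.
template <typename Impl>
class HolderSketch {
public:
    // Builds a new Impl from the arguments, e.g. (1000, 10, 64). The enable_if
    // keeps this template from hijacking copies of the holder itself.
    template <typename... Args,
              typename = std::enable_if_t<std::is_constructible_v<Impl, Args...>>>
    explicit HolderSketch(Args&&... args)
        : impl_(std::make_shared<Impl>(std::forward<Args>(args)...)) {}

    // The arrow operator reaches the shared Impl, as in model->forward(x).
    Impl* operator->() const { return impl_.get(); }

    // How many holders currently share this Impl.
    long use_count() const { return impl_.use_count(); }

private:
    std::shared_ptr<Impl> impl_;
};

// Hypothetical Impl with the same constructor shape as TwoLayerNetImpl.
struct TwoLayerNetSketchImpl {
    TwoLayerNetSketchImpl(long in, long out, long hidden)
        : d_in(in), d_out(out), h(hidden) {}
    long d_in, d_out, h;
};

// Plays the role of the TwoLayerNet name that TORCH_MODULE(TwoLayerNet) provides.
using TwoLayerNetSketch = HolderSketch<TwoLayerNetSketchImpl>;
```

Writing TwoLayerNetSketch b = a; then leaves a and b pointing at the same Impl, so a change made through b is visible through a.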

Optimizer

The optimizer code looks very similar in Python and C++.

# Python SGD optimizer

learning_rate = 1e-4
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.5)
# Zero the gradients before running the backward pass
optimizer.zero_grad()
# Update weights
optimizer.step()

In C++:

// C++ SGD optimizer
double learning_rate = 1e-4;
torch::optim::SGD optimizer(
    model->parameters(),  // model.parameters() for the value-semantics version
    torch::optim::SGDOptions(learning_rate).momentum(0.5));
// Zero the gradients before running the backward pass
optimizer.zero_grad();
// Update the weights
optimizer.step();
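Under the hood, SGD with momentum keeps a velocity buffer per parameter. The PyTorch SGD documentation gives the update (with dampening, weight decay and Nesterov disabled, as in this example) as v = momentum * v + grad followed by p = p - lr * v. Here is a plain C++ sketch of that rule on a flat vector of parameters; SgdMomentumSketch is a hypothetical name, not the torch::optim::SGD implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Plain C++ sketch of SGD with momentum (no dampening, weight decay or
// Nesterov), following the update rule in the PyTorch SGD documentation:
//   v = momentum * v + grad
//   p = p - lr * v
struct SgdMomentumSketch {
    double lr;
    double momentum;
    std::vector<double> velocity;  // one entry per parameter, starts at zero

    void step(std::vector<double>& params, const std::vector<double>& grads) {
        assert(params.size() == grads.size());
        if (velocity.empty()) velocity.assign(params.size(), 0.0);
        for (std::size_t i = 0; i < params.size(); ++i) {
            velocity[i] = momentum * velocity[i] + grads[i];
            params[i] -= lr * velocity[i];
        }
    }
};
```

Starting the velocity at zero gives the same first step as PyTorch, which initializes the buffer to the first gradient when dampening is 0. The sketch takes fresh gradients on every call; in PyTorch, gradients accumulate across backward passes, which is why optimizer.zero_grad() is needed before each one.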

2x+ Performance Improvement

I plotted the training time of the Python and C++ versions while varying the number of epochs from 500 to 5000, on an Intel Ivy Bridge machine running Ubuntu 18.04 LTS. In this setup, C++ model training is more than 2 times faster than Python model training.
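The post does not show its timing code, but one simple way to time a C++ training loop is with std::chrono::steady_clock, which is monotonic and therefore suited to measuring intervals. The helper below is a hypothetical sketch (time_training is a made-up name); in a real run, train_one_epoch would perform the forward pass, loss, backward pass and optimizer.step().

```cpp
#include <cassert>
#include <chrono>

// Hypothetical helper: calls train_one_epoch `epochs` times and returns the
// elapsed wall-clock time in seconds, measured with a monotonic clock.
template <typename F>
double time_training(int epochs, F&& train_one_epoch) {
    const auto start = std::chrono::steady_clock::now();
    for (int epoch = 0; epoch < epochs; ++epoch) train_one_epoch(epoch);
    const auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(stop - start).count();
}
```

Calling it for epoch counts from 500 to 5000 and recording the returned seconds yields the data for a plot like the one described above.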

To conclude, PyTorch maintains parity between its Python and C++ frontend interfaces, and in most cases the C++ frontend follows the design and ergonomics of the Python frontend. The Python and C++ code for this simple neural network example is in the GitHub repo⁴.

References

  1. PyTorch 1.5 release announcement: https://pytorch.org/blog/pytorch-1-dot-5-released-with-new-and-updated-apis/
  2. Justin Johnson. Learning PyTorch with Examples.
  3. Using the PyTorch C++ Frontend.
  4. Code examples: https://github.com/venkatacrc/PyTorchCppFrontEnd

About the author: Venkata Chintapalli, writer for PyTorch. Machine learning, hardware, and high-performance infrastructure enthusiast. Holds an MS in Machine Learning from Georgia Tech and an MTech in Electronics Design from IISc.