Point Net from Scratch
This is the second part of the Point Net Series:
- An Intuitive Introduction to Point Net
- Point Net from Scratch
- Point Net for Classification
- Point Net for Semantic Segmentation
In this article we will learn how to code Point Net from scratch in PyTorch. Point Net is a flexible architecture that can be used for either classification or semantic segmentation. If you are not familiar with Point Net, please see this article. If you would just like to code it, read on: we will break down Point Net and try to understand it piece by piece. The code for this article is stored in this repository.
Point Net Overview
Point Net is a novel architecture that directly consumes point clouds; the architecture is outlined in figure 1. Don’t let this overwhelm you, we will break it down piece by piece. Let’s start with the input: Point Net takes a 3D point cloud of n points (an n × 3 array). A point cloud doesn’t have any order, meaning that we can shuffle the points as much as we want and still have the same point cloud. A point cloud does, however, have an orientation and a position, and no matter how much we change them, the overall structure of the point cloud remains the same. If you are familiar with image processing and Convolutional Neural Networks (CNNs), you may know that images have orientations, and for a CNN to recognize an object in many orientations, it typically must be trained on many orientations. Point Net is designed so that it does not need to be trained on many orientations; instead, it can learn to recognize point clouds regardless of their orientation. How does it do this? The answer is in the T-nets. Let’s take a code-first approach to understanding the T-nets! Once again, the code for this is located here.
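Because the points have no order, whatever Point Net computes for the whole cloud must not depend on the order in which the points arrive. Point Net achieves this by applying the same per-point layers to every point and then aggregating with a symmetric function (max pooling, which appears in both the T-net and the backbone). The toy sketch below, with an arbitrary feature size of 16 and a made-up 8-point cloud, checks that shuffling the points leaves the pooled feature unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A toy point cloud: batch of 1, 3 coordinates (x, y, z), 8 points -> (1, 3, 8)
points = torch.randn(1, 3, 8)

# Shared per-point layer: a 1x1 convolution applies the same weights to every point
shared_mlp = nn.Conv1d(3, 16, kernel_size=1)

def global_feature(pc):
    # Per-point features (1, 16, 8), then max over the point dimension -> (1, 16)
    return torch.relu(shared_mlp(pc)).max(dim=2).values

# Shuffle the order of the points
perm = torch.randperm(points.shape[2])
shuffled = points[:, :, perm]

# The pooled feature does not depend on the point order
print(torch.allclose(global_feature(points), global_feature(shuffled)))  # True
```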
Coding the T-nets
First, let’s understand the architecture of the T-net, which is shown below in figure 2.
The first portion is a series of shared Multilayer Perceptrons (MLPs), followed by a Max Pool and a few Fully Connected (Linear/Dense) layers. The only tricky part is the end. After the final FC layer, we reshape the output into a matrix and add the identity matrix to it. Remember what the T-net does: it learns a transformation that will be applied to each point in the point cloud. The identity matrix is an orthogonal matrix that returns the same point it is multiplied against. By adding it, we are essentially initializing the T-net’s output to the identity transform, and we do this for stability: while the final layer’s output is still close to zero, the learned matrix leaves the points nearly unchanged. If we were to initialize the transform to zero, we would map all the points to zero, and if we were to use a random initialization, we could disrupt the structure of the point cloud.
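To make the identity trick concrete, here is a small sketch. It assumes, for illustration only, that the final FC layer outputs all zeros, and that points are laid out as (batch, num_points, dim) when the transform is applied with a batched matrix multiply.

```python
import torch

batch_size, dim, num_points = 2, 3, 5

# Hypothetical output of the T-net's final FC layer: all zeros for illustration
fc_out = torch.zeros(batch_size, dim * dim)

# Reshape to (batch, dim, dim) and add the identity matrix, as in the T-net
iden = torch.eye(dim).repeat(batch_size, 1, 1)
A = fc_out.view(-1, dim, dim) + iden

# Apply the transform to every point; points are laid out as (batch, num_points, dim)
points = torch.randn(batch_size, num_points, dim)
transformed = torch.bmm(points, A)
print(torch.allclose(transformed, points))  # True: the points are unchanged

# Without the identity, a zero transform would collapse every point to the origin
collapsed = torch.bmm(points, fc_out.view(-1, dim, dim))
print(torch.count_nonzero(collapsed).item())  # 0
```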
We need to write code for two T-nets: one for the input space and the other for the 64-dimensional feature space. We will write a single PyTorch class that can handle both. First, import the necessary libraries:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```
Next, let’s set up the class:
```python
class Tnet(nn.Module):
    ''' T-Net learns a Transformation matrix with a specified dimension '''
    def __init__(self, dim, num_points=2500):
        super(Tnet, self).__init__()

        # dimensions for transform matrix
        self.dim = dim

        # shared MLPs, implemented as 1D convolutions with kernel_size=1
        self.conv1 = nn.Conv1d(dim, 64, kernel_size=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=1)
        self.conv3 = nn.Conv1d(128, 1024, kernel_size=1)

        # fully connected layers; the last one outputs dim*dim values
        self.linear1 = nn.Linear(1024, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, dim**2)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        # the kernel spans all points, so this pools over the whole cloud
        self.max_pool = nn.MaxPool1d(kernel_size=num_points)

    def forward(self, x):
        # x: (batch_size, dim, num_points)
        bs = x.shape[0]

        # pass through shared MLP layers (conv1d)
        x = self.bn1(F.relu(self.conv1(x)))
        x = self.bn2(F.relu(self.conv2(x)))
        x = self.bn3(F.relu(self.conv3(x)))

        # max pool over num points -> (bs, 1024)
        x = self.max_pool(x).view(bs, -1)

        # pass through MLP
        x = self.bn4(F.relu(self.linear1(x)))
        x = self.bn5(F.relu(self.linear2(x)))
        x = self.linear3(x)

        # identity matrix on the same device/dtype as x, one copy per batch element
        iden = torch.eye(self.dim, device=x.device, dtype=x.dtype).repeat(bs, 1, 1)

        # reshape to (bs, dim, dim) and add the identity
        x = x.view(-1, self.dim, self.dim) + iden

        return x
```
The shared MLPs are implemented as 1D convolutions with a kernel size of 1: a convolution applies the same weights at every position, so it naturally provides the weight sharing that shared MLPs require. The ‘dim’ argument sets the dimension of the learned matrix, and the final Linear layer scales up (or down) to dim × dim outputs. We also need to pass the number of points so that the 1D Max Pool kernel spans every point in the cloud. Now let’s move on to coding the main portion of Point Net, which I call the backbone.
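As a quick sanity check of the weight-sharing argument, the sketch below copies the weights of a kernel-size-1 Conv1d into an nn.Linear and applies that Linear layer to each point independently. The sizes (3 inputs, 64 outputs, 10 points, batch of 4) are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

in_dim, out_dim, num_points = 3, 64, 10
x = torch.randn(4, in_dim, num_points)  # (batch, channels, points), the Conv1d layout

conv = nn.Conv1d(in_dim, out_dim, kernel_size=1)
linear = nn.Linear(in_dim, out_dim)

# Copy the conv weights into the linear layer; conv weight is (out, in, 1)
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze(-1))
    linear.bias.copy_(conv.bias)

conv_out = conv(x)  # (4, 64, 10)
# Apply the same Linear layer to every point: move points to dim 1, then back
linear_out = linear(x.transpose(1, 2)).transpose(1, 2)  # (4, 64, 10)

print(torch.allclose(conv_out, linear_out, atol=1e-6))  # True
```

Both forms compute the same thing; the Conv1d form is simply convenient because the data already has the (batch, channels, points) layout.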
Coding the Backbone
The backbone of Point Net ties together the T-nets; it is essentially the classification portion of figure 1, except without the classification head.
The backbone returns either the Global Features (for classification) or a concatenation of the Local and Global Features (for segmentation). The paper used 1024 Global Features, but we can treat this as a hyperparameter if we wish. NOTE: In my experiments, 1024 gave the best performance.
The backbone class includes an argument ‘local_feat’ that controls what its forward() function returns: when set to True, it returns the concatenated local and global features (for segmentation); when set to False, it returns only the global features (for classification). Part of the class initialization is shown below; once again, the full code is located here.
```python
class PointNetBackbone(nn.Module):
    def __init__(self, num_points=2500, num_global_feats=1024, local_feat=True):
        ''' Initializers:
                num_points - number of points in point cloud
                num_global_feats - number of Global Features for the main
                                   Max Pooling layer
                local_feat - if True, forward() returns the concatenation
                             of the local and global features
            '''
        super(PointNetBackbone, self).__init__()
```
The Max Pool function extracts the dominant global features. In our implementation, we can visualize them by having the Max Pool function return the indices of these features, which we do by setting ‘return_indices=True’:
```python
self.max_pool = nn.MaxPool1d(kernel_size=num_points, return_indices=True)
```
PyTorch may emit a UserWarning when this is set; if you would like to suppress it, add the following to your training script:
```python
# suppress PyTorch user warnings
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
```
We refer to the global feature indices as the critical indices, since they index the points that are critical to the overall shape and structure of the point cloud. Figure 4 shows examples of raw point clouds (sets) and their critical point sets. Even though the critical point sets are sparse, they still retain the overall structure of their corresponding point clouds. This suggests that Point Net is learning about the underlying structure of each class.
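Here is a minimal sketch of how the critical indices returned by the Max Pool can be turned into a critical point set. A single random Conv1d stands in for the backbone’s per-point feature layers (an assumption for illustration), and 16 features over 100 points are arbitrary sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

num_points, num_feats = 100, 16
points = torch.randn(1, 3, num_points)  # (batch, xyz, points)

# Stand-in for the backbone's per-point features: (batch, num_feats, points)
features = torch.relu(nn.Conv1d(3, num_feats, kernel_size=1)(points))

# Max pool over all points, keeping the index of the winning point per feature
max_pool = nn.MaxPool1d(kernel_size=num_points, return_indices=True)
global_feats, crit_idxs = max_pool(features)  # both (1, num_feats, 1)

# Each global feature comes from one point; the unique winners form the critical set
critical_idxs = torch.unique(crit_idxs.flatten())
critical_points = points[0, :, critical_idxs]  # (3, num_critical)

print(critical_points.shape[1] <= num_feats)  # True: at most one point per feature
```

With 1024 global features and 2500 input points, the same procedure yields at most 1024 critical points, which is why the critical sets in figure 4 are sparse.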
Coding the Classification Head
Now let’s move on to the classification head, which is a small MLP of fully connected layers that learns output scores for each class. Its input is simply the learned Global Features.
Here’s the beginning of the class initialization; the full code is located here.
```python
class PointNetClassHead(nn.Module):
    ''' Classification Head '''
    def __init__(self, num_points=2500, num_global_feats=1024, k=2):
        super(PointNetClassHead, self).__init__()

        # get the backbone (only need global features for classification)
        self.backbone = PointNetBackbone(num_points, num_global_feats, local_feat=False)
```
And now here’s the forward function.
```python
    def forward(self, x):
        # get global features
        x, crit_idxs, A_feat = self.backbone(x)

        x = self.bn1(F.relu(self.linear1(x)))
        x = self.bn2(F.relu(self.linear2(x)))
        x = self.dropout(x)
        x = self.linear3(x)

        # return logits
        return x, crit_idxs, A_feat
```
In the classification head, the backbone does most of the work by extracting the global features (returned as x). It also returns the critical indices and the feature transformation matrix (A_feat). Returning the critical indices is optional, and you may remove it from your implementation if you wish.
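The forward function above uses linear1 through linear3, bn1, bn2 and dropout, which are defined in the part of the initialization not shown here. The sketch below is a standalone, hypothetical version of just those layers (the name ClassHeadMLP is made up). The sizes 512 and 256 follow the paper’s classification MLP, while the dropout probability of 0.3 is an assumed value.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassHeadMLP(nn.Module):
    ''' Hypothetical sketch of the layers used by PointNetClassHead.forward() '''
    def __init__(self, num_global_feats=1024, k=2, dropout=0.3):
        super().__init__()
        # layer sizes follow the paper's classification MLP (512, 256, k)
        self.linear1 = nn.Linear(num_global_feats, 512)
        self.linear2 = nn.Linear(512, 256)
        self.linear3 = nn.Linear(256, k)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        # dropout probability is an assumed value
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, global_feats):
        # global_feats: (batch_size, num_global_feats)
        x = self.bn1(F.relu(self.linear1(global_feats)))
        x = self.bn2(F.relu(self.linear2(x)))
        x = self.dropout(x)
        return self.linear3(x)  # logits: one score per class

head = ClassHeadMLP(k=10)
logits = head(torch.randn(8, 1024))  # a batch of 8 global feature vectors
print(logits.shape)  # torch.Size([8, 10])
```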
Coding the Segmentation Head
Now we will code the segmentation head of Point Net. This head takes in the concatenation of the learned local and global features, which gives it a rich representation of the input point cloud. The architecture is shown below in figure 6.
At the input, the global features are repeated n times so that they can be concatenated with the local features of each point. The architecture is straightforward: a series of shared MLPs that preserve the n points, with a final layer that projects each point to scores for m possible classes. The first part of the segmentation class init function is shown below; the code is located here.
```python
class PointNetSegHead(nn.Module):
    ''' Segmentation Head '''
    def __init__(self, num_points=2500, num_global_feats=1024, m=2):
        super(PointNetSegHead, self).__init__()

        self.num_points = num_points
        self.m = m

        # get the backbone
        self.backbone = PointNetBackbone(num_points, num_global_feats, local_feat=True)
```
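The concatenation described above can be sketched as follows. Random tensors stand in for the backbone outputs (64 local features per point and 1024 global features, as in the paper), 500 points keeps the example small, and the shared MLP sizes (512, 256, 128, then m) follow the paper’s segmentation network; treat the exact layer arrangement as an assumption rather than the repository’s code.

```python
import torch
import torch.nn as nn

batch_size, num_points, m = 2, 500, 5
local_feats = torch.randn(batch_size, 64, num_points)  # per-point features
global_feats = torch.randn(batch_size, 1024)           # one vector per cloud

# Repeat the global vector once per point and stack it with the local features
global_rep = global_feats.unsqueeze(-1).repeat(1, 1, num_points)  # (2, 1024, 500)
features = torch.cat([local_feats, global_rep], dim=1)            # (2, 1088, 500)

# Shared MLPs as 1x1 convolutions (assumed sizes 512, 256, 128, then m)
seg_mlp = nn.Sequential(
    nn.Conv1d(1088, 512, kernel_size=1), nn.ReLU(), nn.BatchNorm1d(512),
    nn.Conv1d(512, 256, kernel_size=1), nn.ReLU(), nn.BatchNorm1d(256),
    nn.Conv1d(256, 128, kernel_size=1), nn.ReLU(), nn.BatchNorm1d(128),
    nn.Conv1d(128, m, kernel_size=1),
)

scores = seg_mlp(features)  # one score per class for every point
print(features.shape, scores.shape)  # torch.Size([2, 1088, 500]) torch.Size([2, 5, 500])
```

Because every layer is a kernel-size-1 convolution, the point dimension n is preserved all the way to the output.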
Training Point Net
A Point Net model is trained for either classification or semantic segmentation, not both, since the two versions use different heads. In the next tutorial, we will learn how to train Point Net for classification, and we will also see how to visualize the Critical Sets.
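One reason the heads return the feature transformation matrix (A_feat) is that the paper adds a regularization term, the squared Frobenius norm of I − AAᵀ, to the training loss to keep the 64 × 64 feature transform close to orthogonal. The helpers below (feature_transform_reg and pointnet_cls_loss) are hypothetical names, and reg_weight=0.001 is an assumed value; this is a sketch, not the loss used in the accompanying repository.

```python
import torch
import torch.nn.functional as F

def feature_transform_reg(A):
    ''' Hypothetical helper: ||I - A A^T||_F^2 averaged over the batch; A is (bs, k, k) '''
    bs, k, _ = A.shape
    iden = torch.eye(k, device=A.device, dtype=A.dtype).expand(bs, k, k)
    diff = iden - torch.bmm(A, A.transpose(1, 2))
    return (diff ** 2).sum(dim=(1, 2)).mean()

def pointnet_cls_loss(logits, labels, A_feat, reg_weight=0.001):
    # cross-entropy on the class scores plus the weighted regularization term
    return F.cross_entropy(logits, labels) + reg_weight * feature_transform_reg(A_feat)

# An orthogonal (here: identity) transform incurs no penalty
A_ortho = torch.eye(64).repeat(4, 1, 1)
print(feature_transform_reg(A_ortho).item())  # 0.0

logits = torch.randn(4, 10)
labels = torch.randint(0, 10, (4,))
loss = pointnet_cls_loss(logits, labels, A_ortho)
```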
References
Charles, R. Q., Su, H., Kaichun, M., & Guibas, L. J. (2017). PointNet: Deep Learning on point sets for 3D classification and segmentation. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). https://doi.org/10.1109/cvpr.2017.16