Point Transformer: Explanation and PyTorch Code

Heejun Park
Jun 2, 2024


Today I will talk about Point Transformer and a PyTorch implementation of it. The code is not the official implementation; I wrote it myself. The paper was presented at ICCV 2021. Below is the link to the paper.

The authors of this paper are shown below.

1. What is Point Transformer?

Before we begin, I will give you a brief overview of PT (Point Transformer).

PT is a 3D point cloud processing network that utilizes ‘Self-Attention’.

PT can perform Semantic Segmentation, Part Segmentation and Object Classification of 3D point clouds.

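The transformer architecture is well suited to processing point clouds. Why? A point cloud is an unordered set of points, and self-attention is a set operator: reordering the input points only reorders the output features in the same way, so the result does not depend on the order in which the points are stored.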
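To make the permutation argument concrete, here is a minimal sketch. For brevity it uses plain scaled dot-product self-attention with random weights (an assumption; this is not the PT layer itself). Shuffling the input points only shuffles the output rows in the same way, and a symmetric pooling such as a max over the points is then fully order-independent.

```python
import torch

torch.manual_seed(0)


def self_attention(x, w_q, w_k, w_v):
    # plain scaled dot-product self-attention over a set of N feature vectors
    q, k, v = x @ w_q, x @ w_k, x @ w_v
    attn = torch.softmax(q @ k.T / k.size(-1) ** 0.5, dim=-1)  # (N, N)
    return attn @ v  # (N, C)


N, C = 8, 4
x = torch.randn(N, C)  # features of N points
w_q, w_k, w_v = (torch.randn(C, C) for _ in range(3))

perm = torch.randperm(N)  # a random reordering of the points
out = self_attention(x, w_q, w_k, w_v)
out_perm = self_attention(x[perm], w_q, w_k, w_v)

# permuting the input permutes the output rows in the same way
print(torch.allclose(out[perm], out_perm, atol=1e-5))  # True
# after a symmetric pooling over the points, the order no longer matters at all
print(torch.allclose(out.max(dim=0).values, out_perm.max(dim=0).values, atol=1e-5))  # True
```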

2. Point Transformer’s Architecture

In this section, I will talk about the design of PT.

The image below shows the network architecture of PT. There are two different networks in the image.

The network above is for the segmentation tasks (Semantic Segmentation and Part Segmentation). The segmentation network has an encoder-decoder structure (U-Net) with 5 stages in the encoder and 5 stages in the decoder.

In the encoder part, the number of points is downsampled by a factor of 4 from one stage to the next, and the feature dimension doubles at every stage. In the decoder part, the number of points is upsampled by a factor of 4 and the feature dimension is halved at every stage.
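As a quick sanity check of these rates, the sketch below lists the number of points and channels at each encoder stage. It assumes that the first stage keeps all N points and starts at 32 channels; `N = 40960` is an arbitrary example value chosen so every stage gets a whole number of points.

```python
# Hypothetical example size; any N divisible by 4**4 = 256 gives whole numbers.
N = 40960
base_channels = 32  # assumed width of the first stage
num_stages = 5

encoder = []
for stage in range(num_stages):
    points = N // 4 ** stage  # first stage keeps all N points, then /4 per stage
    channels = base_channels * 2 ** stage  # feature dimension doubles per stage
    encoder.append((points, channels))
    print(f"encoder stage {stage + 1}: {points} points x {channels} channels")

# the decoder mirrors the encoder: x4 points and half the channels per stage
decoder = list(reversed(encoder))
```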

The network below is for the object classification task.

N denotes the number of points that go into the network.


The network uses 3 different types of blocks.

  1. Point Transformer Block (Yellow)
  2. Transition Down Block (Blue)
  3. Transition Up Block (Green)

I will now explain each of the blocks in detail.

2.A. Point Transformer Block

The image above depicts the design of the PT block. The PT block is responsible for performing the actual self-attention between the points.

The input is (x, p), where x is the feature vector of an individual point and p is the 3D coordinate of that point. The output is (y, p), where y is the new feature vector and p is the same 3D coordinate, passed through unchanged.

It is composed of two linear layers and one PT layer. The linear layers project the feature vectors to a lower dimension before the PT layer and back afterwards, which reduces the computation. The PT layer performs self-attention inside each local region. A residual connection adds the block input to its output.

Below is the code for the PT block.

```python
import torch.nn as nn


class PointTransformerBlock(nn.Module):
    def __init__(self, in_features, out_features, position_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)  # project to the PT layer width
        self.point_transformer_layer = PointTransformerLayer(out_features, position_dim)
        self.fc2 = nn.Linear(out_features, in_features)  # project back so the residual shapes match

    def forward(self, x, p):
        # x: (N, in_features) point features, p: (N, position_dim) point coordinates
        residual = x  # save the input for the residual connection

        x = self.fc1(x)  # apply the first linear layer
        x = self.point_transformer_layer(x, p)  # apply the Point Transformer layer to the projected features
        y = self.fc2(x)  # apply the second linear layer

        y = y + residual  # add the residual connection

        return y, p
```

Simple, isn’t it?

2.A.1. Point Transformer Layer

Now I will explain the PT layer. This is how the PT layer is defined.
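Written out, the layer definition from the paper is the following, where the sum runs over the neighbors x_j of point i:

$$
y_i = \sum_{x_j \in \mathcal{X}(i)} \rho\big(\gamma(\varphi(x_i) - \psi(x_j) + \delta)\big) \odot \big(\alpha(x_j) + \delta\big)
$$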

The purpose of the PT layer is to perform self-attention in the local region of the point cloud.

The input of the PT layer is (x, p). x is the feature vector. p is the 3D coordinate of the feature vector.

The output of the PT layer is (y, p). y is the transformed feature vector. And p is the 3D coordinate of the feature vector.

The equation has two branches:

  1. Attention Generation Branch
  2. Feature Transformation Branch

The results of the two branches are combined using ⊙, the Hadamard product (element-wise multiplication).

φ, ψ, α: pointwise feature transformations (linear layers).

ρ: a normalization function; here it is a softmax.

X(i): the subset of features in the local neighborhood of x_i (its k nearest neighbors).

x_j: a feature inside the X(i) region.

δ: the position embedding (I will explain this later on).

γ: an MLP with 2 linear layers and 1 ReLU function.
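To see what ρ and the Hadamard product do, here is a toy sketch for a single point i with k = 4 neighbors and C = 3 channels. The random tensors are stand-ins (an assumption, no learned layers are involved) for the output of γ and for α(x_j) + δ. Because ρ normalizes over the neighbors separately for each channel, every channel gets its own set of attention weights, which the Hadamard product then applies.

```python
import torch

torch.manual_seed(0)
k, C = 4, 3  # 4 neighbors in X(i), 3 feature channels

gamma_out = torch.randn(k, C)  # stand-in for gamma(phi(x_i) - psi(x_j) + delta), one row per neighbor j
values = torch.randn(k, C)  # stand-in for alpha(x_j) + delta

weights = torch.softmax(gamma_out, dim=0)  # rho: normalize over the k neighbors, per channel
y_i = (weights * values).sum(dim=0)  # Hadamard product, then sum over j

print(weights.sum(dim=0))  # each channel's weights sum to 1
print(y_i.shape)  # torch.Size([3])
```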

Below is the code of the PT layer.

```python
import torch
import torch.nn as nn


class PointTransformerLayer(nn.Module):
    def __init__(self, feature_dim, position_dim, k=16):
        super().__init__()
        self.k = k  # number of nearest neighbors that form X(i)
        # gamma: an MLP with 2 linear layers and 1 ReLU
        self.gamma_mlp = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim),
        )
        self.phi = nn.Linear(feature_dim, feature_dim)
        self.psi = nn.Linear(feature_dim, feature_dim)
        self.alpha = nn.Linear(feature_dim, feature_dim)
        self.position_encoder = PositionEncoder(position_dim, feature_dim, feature_dim)
        self.rho = nn.Softmax(dim=1)  # normalize over the k neighbors j

    def forward(self, x, p):
        # x: (N, C) features, p: (N, 3) coordinates
        k = min(self.k, x.size(0))
        # indices of the k nearest neighbors of every point (each point is its own nearest neighbor)
        idx = torch.cdist(p, p).topk(k, dim=1, largest=False).indices  # (N, k)

        # calculate position embedding: delta = theta(p_i - p_j)
        position_encoding = self.position_encoder(p.unsqueeze(1), p[idx])  # (N, k, C)

        # calculate attention generation branch: rho(gamma(phi(x_i) - psi(x_j) + delta))
        phi_x = self.phi(x).unsqueeze(1)  # (N, 1, C)
        psi_x = self.psi(x)[idx]  # (N, k, C)
        gamma_output = self.gamma_mlp(phi_x - psi_x + position_encoding)
        attention_generation_branch = self.rho(gamma_output)

        # calculate feature transformation branch: alpha(x_j) + delta
        feature_transformation_branch = self.alpha(x)[idx] + position_encoding

        # Hadamard product, then sum over the neighbors j
        y = torch.sum(attention_generation_branch * feature_transformation_branch, dim=1)

        return y
```

2.A.2. Position Embedding

Now I will talk about how position embedding is performed. Below is the equation.
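For reference, the position encoding function from the paper is:

$$
\delta = \theta(p_i - p_j)
$$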

p_i and p_j are the 3D coordinates of points i and j. θ is an MLP composed of 2 linear layers and 1 ReLU function.

Unlike the original Transformer, which builds its position encoding from sine and cosine (trigonometric) functions of the token index, why does PT use only the 3D coordinates to generate the embedding vector? Because a point cloud is itself a set of 3D coordinates, which can be used directly to compute the position embedding.
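One consequence of feeding relative offsets into θ is easy to check: shifting the whole point cloud by a constant vector leaves every p_i − p_j unchanged, so the encoding is unchanged too. The sketch below verifies this with a randomly initialized θ (2 linear layers and a ReLU, with illustrative sizes 3 → 16 → 16 that are my assumption).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# theta: 2 linear layers with a ReLU in between (sizes are illustrative)
theta = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 16))

p = torch.rand(10, 3)  # 10 points with 3D coordinates
shift = torch.tensor([5.0, -2.0, 3.0])  # translate the whole cloud

rel = p.unsqueeze(1) - p.unsqueeze(0)  # (10, 10, 3): p_i - p_j
rel_shifted = (p + shift).unsqueeze(1) - (p + shift).unsqueeze(0)

delta = theta(rel)
delta_shifted = theta(rel_shifted)
print(torch.allclose(delta, delta_shifted, atol=1e-5))  # True
```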

The position encoding δ is added to both the Attention Generation Branch and the Feature Transformation Branch individually.

Below is the code of the position embedding.

```python
import torch.nn as nn
import torch.nn.functional as F


class PositionEncoder(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, p_i, p_j):
        x = p_i - p_j  # relative position (broadcasts over the neighbors)
        x = F.relu(self.fc1(x))  # theta: linear -> ReLU -> linear
        x = self.fc2(x)

        return x
```

2.B. Transition Down Block

Below is what the TD (transition down) block looks like.

The purpose of the TD block is to reduce the number of points (the cardinality of the point set).

The input is (x, p1) and the output is (y, p2).

Below is how the TD block operates.

  1. Use the ‘Farthest Point Sampling’ algorithm to select a smaller, well-spread subset of points p2 from p1.
  2. Use the kNN algorithm to group, for each point in p2, its nearest points in p1 into a local subset (k = 16).
  3. Each feature inside the subset goes through an MLP individually. This MLP consists of a linear layer, a Batch Normalization layer and a ReLU function.
  4. Max-pool the features within each local subset to obtain the output feature y of the corresponding point in p2.
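The four steps above can be sketched as follows. This is a simplified, unbatched illustration, not the paper's code: the farthest point sampling is a plain Python loop, the names `farthest_point_sampling` and `TransitionDown` are hypothetical, and the `rate`/`k` defaults (4 and 16) come from the numbers in this post. The shared Linear → BatchNorm → ReLU is applied once per input point before grouping, so points that appear in several neighborhoods are not recomputed.

```python
import torch
import torch.nn as nn


def farthest_point_sampling(p, m):
    # p: (N, 3) coordinates; returns indices of m well-spread points
    n = p.size(0)
    idx = torch.zeros(m, dtype=torch.long)
    dist = torch.full((n,), float("inf"))
    idx[0] = 0  # start from the first point (an arbitrary choice)
    for s in range(1, m):
        # squared distance from every point to its closest already-selected point
        dist = torch.minimum(dist, ((p - p[idx[s - 1]]) ** 2).sum(dim=1))
        idx[s] = torch.argmax(dist)  # pick the point farthest from the selection
    return idx


class TransitionDown(nn.Module):
    def __init__(self, in_features, out_features, rate=4, k=16):
        super().__init__()
        self.rate, self.k = rate, k
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x, p1):
        # x: (N, in_features) features at p1: (N, 3)
        m = max(1, p1.size(0) // self.rate)
        p2 = p1[farthest_point_sampling(p1, m)]  # step 1: (M, 3)
        k = min(self.k, p1.size(0))
        knn = torch.cdist(p2, p1).topk(k, dim=1, largest=False).indices  # step 2: (M, k)
        h = torch.relu(self.bn(self.linear(x)))  # step 3: Linear -> BN -> ReLU per point
        y = h[knn].max(dim=1).values  # step 4: max-pool over the k neighbors, (M, out_features)
        return y, p2


torch.manual_seed(0)
x = torch.randn(64, 32)  # 64 input points with 32-dim features
p1 = torch.rand(64, 3)
y, p2 = TransitionDown(32, 64)(x, p1)
print(y.shape, p2.shape)  # torch.Size([16, 64]) torch.Size([16, 3])
```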

2.C. Transition Up Block

Below is what the TU (transition up) block looks like.

Each TU block is paired with a TD block in the encoder part.

The purpose of the TU block is to upsample features from p1 to p2.

This is how the TU block operates.

  1. Each input feature goes through a linear layer, a Batch Normalization layer and a ReLU function.
  2. Interpolation is conducted to map the features onto the higher-resolution point set p2. (Trilinear interpolation is used.)
  3. Add the interpolated result to the features of the paired encoder stage at p2, provided through a skip connection.
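A matching sketch of the TU block is below. One assumption to flag: the paper calls the upsampling step trilinear interpolation, but on unstructured points it is commonly implemented (for example in PointNet++) as inverse-distance weighting over the 3 nearest coarse points, and that is what this sketch does. The class name `TransitionUp` and its arguments are hypothetical; `skip` stands for the encoder features at p2.

```python
import torch
import torch.nn as nn


class TransitionUp(nn.Module):
    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.k = k  # number of coarse neighbors used for interpolation
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)

    def forward(self, x, p1, skip, p2):
        # x: (M, in_features) coarse features at p1: (M, 3)
        # skip: (N, out_features) encoder features at the finer points p2: (N, 3)
        h = torch.relu(self.bn(self.linear(x)))  # step 1: Linear -> BN -> ReLU
        k = min(self.k, p1.size(0))
        dist, idx = torch.cdist(p2, p1).topk(k, dim=1, largest=False)  # (N, k)
        w = 1.0 / (dist + 1e-8)  # inverse-distance weights
        w = w / w.sum(dim=1, keepdim=True)  # normalize so the weights sum to 1
        up = (h[idx] * w.unsqueeze(-1)).sum(dim=1)  # step 2: interpolate onto p2
        return up + skip  # step 3: add the skip connection


torch.manual_seed(0)
p2 = torch.rand(64, 3)  # finer point set
p1 = p2[:16]  # coarser subset (stand-in for the points kept by a TD block)
x = torch.randn(16, 64)  # coarse decoder features
skip = torch.randn(64, 32)  # encoder features at p2
y = TransitionUp(64, 32)(x, p1, skip, p2)
print(y.shape)  # torch.Size([64, 32])
```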

3. Experiments

3.A. Settings

Only one Quadro RTX 6000 GPU was used.

The table below shows the inference time and the memory required to process different numbers of points during the semantic segmentation task.

Table Created by me :)

3.B. Evaluation Metrics

3.B.1. mIoU

3.B.2. mAcc

3.B.3. OA
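The three metrics follow their standard definitions: OA is the fraction of correctly labeled points, mAcc is the mean of the per-class accuracies, and mIoU is the mean of the per-class intersection over union, TP / (TP + FP + FN). The sketch below computes them from a confusion matrix, assuming `conf[t, p]` counts the points of true class t predicted as class p; the 3-class matrix is made-up example data.

```python
import numpy as np


def segmentation_metrics(conf):
    # conf[t, p]: number of points with true class t predicted as class p
    conf = np.asarray(conf, dtype=np.float64)
    tp = np.diag(conf)  # correctly labeled points per class
    gt = conf.sum(axis=1)  # points per true class (TP + FN)
    pred = conf.sum(axis=0)  # points per predicted class (TP + FP)

    oa = tp.sum() / conf.sum()  # overall accuracy
    macc = np.mean(tp / gt)  # mean of the per-class accuracies
    miou = np.mean(tp / (gt + pred - tp))  # mean of TP / (TP + FP + FN)
    return oa, macc, miou


conf = [[50, 5, 5],
        [10, 30, 0],
        [0, 0, 100]]
oa, macc, miou = segmentation_metrics(conf)
print(f"OA={oa:.4f}  mAcc={macc:.4f}  mIoU={miou:.4f}")  # OA=0.9000  mAcc=0.8611  mIoU=0.7778
```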

3.C. Dataset

3.C.1. Semantic segmentation

S3DIS

3.C.2. Object Classification

ModelNet40

3.C.3. Part Segmentation

ShapeNetPart

3.D. Semantic Segmentation Result

Results on Area 5 of S3DIS

3.E. Object Classification

Classification result on ModelNet40

3.F. Part Segmentation

Result on ShapeNetPart

3.G. Ablation Experiments

3.G.1. Number of neighbors k

3.G.2. Softmax Regularization

ρ is the softmax function in the PT layer.

3.G.3. Position Encoding

4. Discussion

4.A. Position Embedding

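It is very interesting how the PT model uses the 3D coordinates of the point cloud to compute the position embedding during the self-attention process. If you think about it, it makes a lot of sense: every point already carries its own 3D coordinates, so the relative offset p_i − p_j is a natural, geometry-aware signal that tells the attention where each neighbor lies.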


Heejun Park

3D Vision Enthusiast @ KAIST Visual Intelligence Lab. Here's a link to my GitHub page https://github.com/parkie0517