Point Net for Classification

How to Train Point Net for Point Cloud Classification

Isaac Berrios
9 min read · Dec 11, 2022

This is the third part of the Point Net Series:

  1. An Intuitive Introduction to Point Net
  2. Point Net from Scratch
  3. Point Net for Classification
  4. Point Net for Semantic Segmentation

In this tutorial we will learn how to train Point Net for classification. We will focus mainly on the data and the training process; the tutorial that shows how to code Point Net from scratch is located here. The code for this tutorial is located in this repository, and the notebook we will be using is located here within that repository. Some of the code was inspired by this repository.

Getting the Data

We will be working with a smaller version of the shapenet dataset that only has 16 classes. If you are using Colab, you can run the following code to obtain the data. WARNING, this will take a long time.

!wget -nv https://shapenet.cs.stanford.edu/ericyi/shapenetcore_partanno_segmentation_benchmark_v0.zip --no-check-certificate
!unzip shapenetcore_partanno_segmentation_benchmark_v0.zip
!rm shapenetcore_partanno_segmentation_benchmark_v0.zip

If you would like to work locally, visit the link in the first line above and the data will be downloaded automatically as a zip file.

The dataset contains 16 folders with class identifiers (called “synsetoffset” in the README). The folder structure is:

  • synsetoffset
    - points: uniformly sampled points from ShapeNetCore models
    - point_labels: per-point segmentation labels
    - seg_img: a visualization of labeling
  • train_test_split: JSON files with train/validation/test splits

A custom PyTorch dataset is located here; explaining the code is outside the scope of this tutorial. The important thing to know is that the dataset can return either (point_cloud, class) or (point_cloud, seg_labels). During training and validation we add Gaussian noise to the point clouds and randomly rotate them about the vertical axis (the y-axis in this case). We also apply min-max normalization to the point clouds so that they have a range of 0–1. We can create an instance of the shapenet dataset like so:

from shapenet_dataset import ShapenetDataset

# __getitem__ returns (point_cloud, class)
train_dataset = ShapenetDataset(ROOT, npoints=2500, split='train', classification=True)
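
For reference, here is a rough sketch of the augmentations described above. The real implementation lives in shapenet_dataset.py; the function name and the noise level below are just illustrative.

import numpy as np

# illustrative sketch of the training-time augmentations (not the repo's exact code)
def augment_point_cloud(points, noise_std=0.02):
    # random rotation about the vertical (y) axis
    theta = np.random.uniform(0, 2*np.pi)
    rot_y = np.array([[ np.cos(theta), 0, np.sin(theta)],
                      [ 0.0,           1, 0.0          ],
                      [-np.sin(theta), 0, np.cos(theta)]])
    points = points @ rot_y.T

    # small amount of Gaussian noise
    points = points + np.random.normal(0.0, noise_std, size=points.shape)

    # min-max normalize each axis to the range 0-1
    points = (points - points.min(axis=0)) / (points.max(axis=0) - points.min(axis=0))
    return points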

Exploring the data

Before we get started with any training, let’s explore some of the training data. To do this we will use Open3D version 0.16.0 (must be 0.16.0 or higher).

!pip install open3d==0.16.0

We can now view a sample point cloud with the following code. You should notice that the point cloud is displayed in a different orientation each time you run the code.

import open3d as o3
from shapenet_dataset import ShapenetDataset

# read_pointnet_colors is a small helper in the tutorial repo that maps
# per-point segmentation labels to RGB colors

sample_dataset = ShapenetDataset(ROOT, npoints=20000, split='train',
                                 classification=False, normalize=False)

points, seg = sample_dataset[4000]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(points)
pcd.colors = o3.utility.Vector3dVector(read_pointnet_colors(seg.numpy()))

o3.visualization.draw_plotly([pcd])
Figure 1. Noisy point cloud with random rotation. Y-axis is the vertical axis. Source: Author.

You probably won’t notice much of a difference from the noise since we add such a small amount; we keep it small because we don’t want to greatly disrupt the structures, but even this small amount is enough to have an impact on the model. Now let’s take a look at the training class frequencies.

Figure 2. Frequencies of the training classes. Source: Author.

We can see in figure 2 that this is definitely not a balanced training set. Because of this we may want to apply class weighting or even use the Focal Loss to help our model learn.
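
As a quick aside, one simple way to derive class weights is to count the per-class sample frequencies and take their inverse. This snippet is not part of the tutorial notebook, just a sketch that relies on the dataset returning (point_cloud, class) as configured above (it iterates the whole dataset once, so it is slow but simple):

from collections import Counter
import numpy as np

# count how many training samples each class has
counts = Counter(int(label) for _, label in train_dataset)
freqs = np.array([counts[i] for i in sorted(counts)], dtype=np.float64)

# inverse-frequency weights, scaled so the average weight is 1; these could
# be passed as the 'alpha' term of the loss function defined later on
class_weights = freqs.sum() / (len(freqs) * freqs)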

Point Net Loss function

When training Point Net for classification we can use the standard Cross Entropy Loss from PyTorch, but we also want to add the regularization term mentioned in the paper [1]. The regularization term forces the feature transformation matrix to be orthogonal, but why? The feature transform matrix is intended to rotate (transform) the high dimensional representation of the point cloud. How can we be sure that this learned high dimensional transformation actually behaves like a rotation? To answer this, let’s consider some desired properties of a rotation. We want the learned transformation to preserve structure; we want to be sure that it is not doing something strange, like collapsing the representation into a lower dimensional space or distorting its geometry. We can’t just plot an n x 64 point cloud to check this, but we can get the model to learn a valid rotation by encouraging that rotation to be orthogonal. This is because orthogonal matrices preserve both length and angle, and rotation matrices are a special type of orthogonal matrix [2]. We can ‘encourage’ the model to learn an orthogonal rotation matrix through regularization with the term:

Figure 3. Point Net regularization term. Source: [1].

We exploit a fundamental property of orthogonal matrices, which is that their rows and columns are orthonormal vectors. The regularization term in figure 3 will be equal to zero for a perfectly orthogonal matrix [2].
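
For reference, the term in figure 3 can be written as

L_reg = ‖ I − A Aᵀ ‖²_F

where A is the learned 64x64 feature transform matrix and ‖·‖_F is the Frobenius norm. (The implementation below uses the unsquared norm, scaled by a regularization weight and divided by the batch size, which serves the same purpose.)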

During training we simply add this term to our loss. If you have gone through the previous tutorial on how to code a point net, you may remember that the feature transformation matrix A is returned by the classification head.

Now let’s code the Point Net loss function. We’ve gone ahead and added terms for weighted (balanced) Cross Entropy Loss and Focal Loss, but explaining them is outside the scope of this tutorial. The code for this is located here. This code was adapted from this repository.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class PointNetLoss(nn.Module):
    def __init__(self, alpha=None, gamma=0, reg_weight=0, size_average=True):
        super(PointNetLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reg_weight = reg_weight
        self.size_average = size_average

        # sanitize inputs
        if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
        if isinstance(alpha, (list, np.ndarray)): self.alpha = torch.Tensor(alpha)

        # get Balanced Cross Entropy Loss
        self.cross_entropy_loss = nn.CrossEntropyLoss(weight=self.alpha)

    def forward(self, predictions, targets, A):

        # get batch size
        bs = predictions.size(0)

        # get Balanced Cross Entropy Loss
        ce_loss = self.cross_entropy_loss(predictions, targets)

        # reformat predictions and targets (segmentation only)
        if len(predictions.shape) > 2:
            predictions = predictions.transpose(1, 2) # (b, c, n) -> (b, n, c)
            predictions = predictions.contiguous() \
                                     .view(-1, predictions.size(2)) # (b, n, c) -> (b*n, c)

        # get predicted class probabilities for the true class
        pn = F.softmax(predictions, dim=1)
        pn = pn.gather(1, targets.view(-1, 1)).view(-1)

        # get regularization term
        if self.reg_weight > 0:
            I = torch.eye(64).unsqueeze(0).repeat(A.shape[0], 1, 1) # .to(device)
            if A.is_cuda: I = I.cuda()
            reg = torch.linalg.norm(I - torch.bmm(A, A.transpose(2, 1)))
            reg = self.reg_weight*reg/bs
        else:
            reg = 0

        # compute loss (negative sign is included in ce_loss)
        loss = ((1 - pn)**self.gamma * ce_loss)
        if self.size_average: return loss.mean() + reg
        else: return loss.sum() + reg

Training Point Net for Classification

Now that we have an understanding of the data and the loss function, we can move on to the training. For our training we will want to quantify how well our model is performing. Typically we look at loss and accuracy, but for this classification problem we will need a metric that accounts for incorrect classification as well as correct classification. Think of the typical confusion matrix: true positives, false negatives, true negatives, and false positives; we want a classifier that performs well on all of these. The Matthews Correlation Coefficient (MCC) quantifies how well our model performs on all of these metrics and is considered to be a more reliable single metric of performance than accuracy or F1 score [3]. The MCC ranges from -1 to 1, where -1 is the worst performance, 1 is the best performance, and 0 is a Random Guess. We can use the MCC with PyTorch via torchmetrics.

from torchmetrics.classification import MulticlassMatthewsCorrCoef

mcc_metric = MulticlassMatthewsCorrCoef(num_classes=NUM_CLASSES).to(DEVICE)
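
As a quick illustration (this snippet is not part of the notebook), the metric can be called directly on a batch of class logits and integer targets:

import torch

# illustrative only: random logits and targets just to show the call signature
preds = torch.randn(32, NUM_CLASSES).to(DEVICE)
targets = torch.randint(0, NUM_CLASSES, (32,)).to(DEVICE)

mcc = mcc_metric(preds, targets)   # scalar tensor in [-1, 1]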

The training process is a basic PyTorch training loop that alternates between training and validation. We use the Adam optimizer and our Point Net loss function with the regularization term described above in figure 3. For the Point Net loss function we choose to set alpha, which weights the importance of each class. We also set gamma, which modulates the loss function and forces it to focus on the hard examples, where hard examples are those that are classified with a lower probability. See the notes in the notebook for more details. We noticed that the model trained better when using a cyclic learning rate, so we use one here.

import torch.optim as optim
from point_net_loss import PointNetLoss

EPOCHS = 50
LR = 0.0001
REG_WEIGHT = 0.001

# manually downweight the high frequency classes
alpha = np.ones(NUM_CLASSES)
alpha[0] = 0.5 # airplane
alpha[4] = 0.5 # chair
alpha[-1] = 0.5 # table

gamma = 1

optimizer = optim.Adam(classifier.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.01,
                                              step_size_up=2000, cycle_momentum=False)
criterion = PointNetLoss(alpha=alpha, gamma=gamma, reg_weight=REG_WEIGHT).to(DEVICE)

classifier = classifier.to(DEVICE)
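
The core of the loop looks roughly like the sketch below; the full version, with validation, accuracy, and MCC tracking, is in the notebook. Here train_dataloader is assumed to be a DataLoader wrapping train_dataset, and the classifier from the previous tutorial is assumed to return its predictions along with the critical indexes and the feature transform matrix A:

classifier.train()
for points, targets in train_dataloader:
    points = points.transpose(2, 1).to(DEVICE)   # (b, n, 3) -> (b, 3, n)
    targets = targets.squeeze().to(DEVICE)

    optimizer.zero_grad()
    preds, _, A = classifier(points)             # (predictions, crit_idxs, A)
    loss = criterion(preds, targets, A)
    loss.backward()
    optimizer.step()
    scheduler.step()                             # the cyclic LR is stepped every batch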

Please follow the notebook for the full training loop and make sure you have a GPU. If not, remove the scheduler and set the learning rate to 0.01; you should get decent enough results after a few epochs. If you run into any PyTorch user warnings (due to a future update to nn.MaxPool1d), you can suppress them with:

import warnings
warnings.filterwarnings("ignore")

Training Results

Figure 4. Training metrics. Source: Author.

We can see that the accuracy goes up for both training and validation, but the MCC only goes up for training and not validation. This could be caused by the very small sample sizes of some classes in the validation and test splits; in that case the MCC may not be the best single metric for validation and test. This warrants some more investigation into when the MCC is a good metric: how much imbalance is too much for the MCC? How many samples does each class need for the MCC to be effective?

Let’s look at the test results:

Figure 5. Test metrics. Source: Author.

We see that the test accuracy is decent at around 85%, but the MCC is just above 0. Since we only have 16 classes, let’s view the confusion matrix in the notebook to get some more insight into the test results.

Figure 6. Test data Confusion matrix. Source: Author.

For the most part the classification is okay, but the model tends to perform poorly on a few less common classes such as ‘rocket’ and ‘skateboard’, and this poor performance on the rare classes is what drives the MCC down. Another thing to notice is that when you inspect the results (as shown in the notebook) you get good accuracy and confident predictions on the more frequent classes, while on the less frequent classes the confidence is lower and the accuracy is worse.

Inspecting the critical sets

Now we will look at one of the most interesting parts of this tutorial: the critical sets. The critical set is the essential subset of points that defines the basic structure of a point cloud. Here’s some code showing how to obtain and visualize them.
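
First we need the critical indexes (crit_idxs), which come from running a sample through the trained classifier. The exact inference cell is in the notebook; the sketch below assumes, as in the previous tutorial, that the classifier returns its predictions along with the critical indexes and the feature transform matrix:

import torch

# sketch only: run one sample through the trained model to get crit_idxs
points, seg = sample_dataset[4000]
pts = torch.as_tensor(points, dtype=torch.float32)   # handles numpy or tensor input

classifier.eval()
with torch.no_grad():
    _, crit_idxs, _ = classifier(pts.unsqueeze(0).transpose(2, 1).to(DEVICE))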

from open3d.web_visualizer import draw

# index the original points and their segmentation colors with crit_idxs
# to recover the critical set
critical_points = points[crit_idxs.cpu().squeeze(), :]
critical_point_colors = read_pointnet_colors(seg.numpy())[crit_idxs.cpu().squeeze(), :]

pcd = o3.geometry.PointCloud()
pcd.points = o3.utility.Vector3dVector(critical_points)
pcd.colors = o3.utility.Vector3dVector(critical_point_colors)

# o3.visualization.draw_plotly([pcd])  # use this in Colab
draw(pcd, point_size=5) # 'draw' does not work in Colab

Here are some visualizations. Note that I used ‘draw()’ to get larger point sizes, but it does not work in Colab.

Figure 7. Point Cloud Sets and their corresponding critical sets learned by Point Net. Source: Author.

We can see that the critical sets capture the overall structure of their corresponding point clouds; they are essentially sparsely sampled point clouds. This indicates that the trained model has actually learned to distinguish the different structures and suggests that it is able to classify each point cloud category based on its distinguishing structure.

Conclusion

If you have made it this far, congratulations! You have learned how to train Point Net from scratch and even how to visualize the critical sets. I would encourage you to go back and make sure you understand everything, and if you are really interested, try to improve the overall classification performance. Here are some suggestions to get you started:

  • Using a different loss function
  • Try different settings in the cyclic learning rate scheduler
  • Experiment with modifications to the Point Net architecture
  • Experiment with different data augmentations
  • Use more data → Try the full shapenet dataset

References

[1] Charles, R. Q., Su, H., Kaichun, M., & Guibas, L. J. (2017). PointNet: Deep Learning on point sets for 3D classification and segmentation. 2017 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). https://doi.org/10.1109/cvpr.2017.16

[2] Knill, O. (n.d.). Unit 8: The Orthogonal Group — Harvard University. people.math.harvard.edu. Retrieved December 10, 2022, from https://people.math.harvard.edu/~knill/teaching/math22b2019/handouts/lecture08.pdf

[3] Chicco, D., & Jurman, G. (2020). The advantages of the Matthews correlation coefficient (MCC) over F1 score and accuracy in binary classification evaluation. BMC Genomics, 21(1). https://doi.org/10.1186/s12864-019-6413-7

