3D rotations and spatial transformations made easy with RoMa
Struggling with quaternions, rotation vectors, right-hand rules and all that stuff? Try RoMa: an easy-to-use, stable and efficient library to deal with rotations and spatial transformations in PyTorch.
A bit of context
Spatial transformations are essential in physics, engineering, and computer vision. Scientists have been studying 3D geometry for centuries and produced great mathematical tools to deal with such transformations, especially rigid motions and rotations. Yet, implementing these tools properly is not trivial.
Indeed, one often has to use different conventions or representations for different tasks. Quaternions, for example, are great for composing rotations; rotation matrices are ideal for transforming the coordinates of large numbers of points; rotation vectors are compact and useful for representing small rotations; and Euler angles are bad for pretty much everything except user input.
Additionally, turning math into code leads to numerical issues that can be hard to spot and cause subtle bugs. Fortunately, these issues can be mitigated by picking the right algorithms or by handling special cases. For example, the angle of a rotation matrix can be computed with different formulas that are mathematically equivalent, but produce drastically different numerical errors in 32-bit floating-point precision.
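To make this concrete, here is a minimal pure-PyTorch sketch (for illustration only, not RoMa's actual implementation) comparing two equivalent formulas for the rotation angle, evaluated in float32 on a very small rotation:

```python
import torch

# A small rotation of 1e-4 radians about the z-axis, in 32-bit floats.
theta = 1e-4
c = torch.cos(torch.tensor(theta)).item()
s = torch.sin(torch.tensor(theta)).item()
R = torch.tensor([[c, -s, 0.0],
                  [s,  c, 0.0],
                  [0.0, 0.0, 1.0]])

# Formula 1: theta = arccos((trace(R) - 1) / 2).
# Mathematically exact, but arccos is ill-conditioned near 1:
# in float32, trace(R) rounds to exactly 3 and the angle collapses to 0.
angle_acos = torch.acos((torch.trace(R) - 1.0) / 2.0)

# Formula 2: theta = 2 * arcsin(||R - I||_F / (2 * sqrt(2))).
# Equivalent on paper, but numerically stable for small angles.
angle_asin = 2.0 * torch.asin(torch.linalg.norm(R - torch.eye(3)) / (2.0 * 2.0 ** 0.5))

print(angle_acos.item())  # 0.0 -- the small angle is lost entirely
print(angle_asin.item())  # ~1e-4 -- close to the true angle
```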
Introducing RoMa
RoMa aims to overcome these obstacles for you. RoMa is a Python library compatible with PyTorch version 1.6 and above. It provides an easy-to-use, stable and efficient toolbox to deal with rotations, as well as more general spatial transformations.
Conversions between rotation representations
RoMa provides differentiable routines to convert between various rotation representations. For example, one could sample a 3D rotation vector and convert it to a unit quaternion as follows using RoMa:
import torch, roma
rotvec = torch.randn(3) # 3D rotation vector
q = roma.rotvec_to_unitquat(rotvec) # unit quaternion, represented by a 4D tensor
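For intuition, this conversion follows the classic axis-angle formula q = [axis * sin(θ/2), cos(θ/2)], with RoMa storing quaternions in XYZW convention (scalar part last). A hand-rolled pure-PyTorch sketch of the same mapping (use roma.rotvec_to_unitquat in practice, which handles the small-angle limit more carefully):

```python
import torch

def rotvec_to_unitquat_sketch(rotvec):
    # Hypothetical re-implementation for illustration only.
    # Note: this naive version divides by zero for the null rotation.
    angle = torch.linalg.norm(rotvec)
    axis = rotvec / angle
    xyz = axis * torch.sin(angle / 2.0)
    w = torch.cos(angle / 2.0)
    return torch.cat([xyz, w.reshape(1)])  # XYZW: scalar part last

# A 90-degree rotation about the z-axis.
q = rotvec_to_unitquat_sketch(torch.tensor([0.0, 0.0, torch.pi / 2]))
print(q)  # ~[0, 0, 0.7071, 0.7071]
```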
Batched data
For convenience, functions in RoMa support arbitrary numbers of batch dimensions. One could for example sample a batch of 2x5 random 3D rotation vectors — represented by a 2x5x3 tensor — and convert it to a batch of unit quaternions — represented by a 2x5x4 tensor — using the same syntax as above:
import torch, roma
rotvec = torch.randn(2,5,3)
q = roma.rotvec_to_unitquat(rotvec)
Regressing rotations
Regressing rotations using a neural network is nontrivial, because classical neural architectures produce outputs lying in Euclidean space. RoMa implements various differentiable functions to map such outputs onto the rotation space, e.g. by performing special Procrustes orthonormalization of an arbitrary matrix:
import torch, roma
M = my_fancy_neural_network(some_input) # Method returning an arbitrary 3x3 matrix
R = roma.special_procrustes(M) # Orthonormalizing M into a rotation matrix
assert roma.is_rotation_matrix(R, epsilon=1e-5)
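Special Procrustes orthonormalization finds the rotation matrix closest, in the Frobenius sense, to an arbitrary matrix. The textbook solution goes through an SVD; here is a pure-PyTorch sketch of that idea (RoMa's special_procrustes additionally takes care of batching, gradients and degenerate cases):

```python
import torch

def special_procrustes_sketch(M):
    # Nearest rotation to M in the Frobenius sense, via SVD.
    # The sign correction on the last singular vector enforces
    # det(R) = +1 (a proper rotation), not just orthogonality.
    U, _, Vh = torch.linalg.svd(M)
    d = torch.det(U @ Vh)
    S = torch.diag(torch.tensor([1.0, 1.0, d]))
    return U @ S @ Vh

torch.manual_seed(0)  # for a reproducible example
M = torch.randn(3, 3)
R = special_procrustes_sketch(M)
print(torch.det(R))  # ~1.0
print(R @ R.T)       # ~identity
```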
Rigid transformations
For more readable code, RoMa also includes utilities to deal with non-linear spatial transformations, and notably rigid motions:
import torch, roma
# Rigid transformation parameterized by a rotation matrix and a translation vector
T1 = roma.Rigid(linear=roma.random_rotmat(), translation=torch.randn(3))
T2 = roma.Rigid(linear=roma.random_rotmat(), translation=torch.randn(3))
# Inverting and composing transformations
T = (T1.inverse() @ T2)
# Normalization to ensure that T is actually a rigid transformation.
T = T.normalize()
# Direct access to the translation part
T.translation += 0.5
# Transformation of points:
points = torch.randn(100,3)
# Adjusting the shape of T for proper broadcasting.
transformed_points = T[None].apply(points)
# Transformation of vectors:
vectors = torch.randn(10,20,3)
# Adjusting the shape of T for proper broadcasting.
transformed_vectors = T[None,None].linear_apply(vectors)
# Casting the transformation into a homogeneous 4x4 matrix.
M = T.to_homogeneous()
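The homogeneous 4x4 form is handy because composing rigid transformations becomes plain matrix multiplication. A quick pure-PyTorch sanity check of that identity, independent of RoMa and with hand-built example rotations:

```python
import torch

def to_homogeneous(R, t):
    # Pack a rotation R and translation t into a 4x4 matrix.
    M = torch.eye(4)
    M[:3, :3] = R
    M[:3, 3] = t
    return M

def rot_z(a):
    c, s = torch.cos(a), torch.sin(a)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

R1, t1 = rot_z(torch.tensor(0.3)), torch.tensor([1.0, 2.0, 3.0])
R2, t2 = rot_z(torch.tensor(-1.1)), torch.tensor([0.5, 0.0, -2.0])

# Composing (R1, t1) after (R2, t2): x -> R1 @ (R2 @ x + t2) + t1
R12, t12 = R1 @ R2, R1 @ t2 + t1

# The same composition expressed with homogeneous matrices.
M12 = to_homogeneous(R1, t1) @ to_homogeneous(R2, t2)
assert torch.allclose(M12, to_homogeneous(R12, t12), atol=1e-6)
```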
RoMa can be installed easily using pip, so give it a try:
pip install roma
and see the documentation and GitHub repository for more details.