SE3 exponential and logarithm maps.

Summary:
Implements the SE(3) logarithm and exponential maps.
(This is the second part of the split of D23326429.)

Outputs of `bm_se3`:
```
Benchmark         Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
SE3_EXP_1                738             885            678
SE3_EXP_10               717             877            698
SE3_EXP_100              718             847            697
SE3_EXP_1000             729            1181            686
--------------------------------------------------------------------------------

Benchmark          Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
SE3_LOG_1               1451            2267            345
SE3_LOG_10              2185            2453            229
SE3_LOG_100             2217            2448            226
SE3_LOG_1000            2455            2599            204
--------------------------------------------------------------------------------
```

Reviewed By: patricklabatut

Differential Revision: D27852557

fbshipit-source-id: e42ccc9cfffe780e9cad24129de15624ae818472
This commit is contained in:
David Novotny 2021-06-21 04:47:31 -07:00 committed by Facebook GitHub Bot
parent 9f14e82b5a
commit b2ac2655b3
5 changed files with 561 additions and 2 deletions

View File

@ -20,6 +20,7 @@ from .rotation_conversions import (
rotation_6d_to_matrix,
standardize_quaternion,
)
from .se3 import se3_exp_map, se3_log_map
from .so3 import (
so3_exponential_map,
so3_exp_map,

213
pytorch3d/transforms/se3.py Normal file
View File

@ -0,0 +1,213 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import torch
from .so3 import hat, _so3_exp_map, so3_log_map
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Apply the SE(3) exponential map to a batch of 6D logarithms, producing
    a batch of 4x4 SE(3) matrices. See [1], Sec 9.4.2. for the derivation.

    The returned matrices follow the row-vector layout
    ```
    [ R 0 ]
    [ T 1 ] ,
    ```
    with `R` a 3x3 rotation matrix and `T` a 3-D translation vector stored
    in the last row. Such matrices commonly represent rigid motions or
    camera extrinsics.

    Each input row is the concatenation `[log_translation | log_rotation]`
    of two 3D vectors, and the output is defined as
    ```
    transform = exp( [ hat(log_rotation) 0 ]
                     [ log_translation   1 ] ) ,
    ```
    where `exp` is the matrix exponential and `hat` is the Hat operator [2].

    `se3_log_map` is the inverse operation:
    `se3_log_map(se3_exp_map(log_transform))` recovers `log_transform` as long
    as the norm of `log_rotation` stays inside the range of angles that the
    rotation logarithm can return.

    The map is singular for rotation logarithms of (near-)zero norm; the
    `eps` clamp guards that case.

    Args:
        log_transform: Batch of vectors of shape `(minibatch, 6)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid unstable gradients in the singular case.
            It is passed to `_so3_exp_map`.

    Returns:
        Batch of transformation matrices of shape `(minibatch, 4, 4)`.
        The result is a permuted view of a freshly allocated tensor.

    Raises:
        ValueError if `log_transform` is of incorrect shape.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """
    if not (log_transform.ndim == 2 and log_transform.shape[1] == 6):
        raise ValueError("Expected input to be of shape (N, 6).")

    batch_size = log_transform.shape[0]
    translation_part = log_transform[:, :3]
    rotation_part = log_transform[:, 3:]

    # The rotation block is the SO(3) exponential of the rotation part; the
    # helper also hands back the intermediates needed to build V below.
    so3_outputs = _so3_exp_map(rotation_part, eps=eps)
    R, angles, rotation_hat, rotation_hat_square = so3_outputs

    # The translation is obtained by applying V to the log-translation.
    V = _se3_V_matrix(
        rotation_part,
        rotation_hat,
        rotation_hat_square,
        angles,
        eps=eps,
    )
    T = torch.bmm(V, translation_part[..., None])[..., 0]

    # Assemble the column-vector form [R T; 0 1] first ...
    assembled = torch.zeros(
        batch_size, 4, 4, device=log_transform.device, dtype=log_transform.dtype
    )
    assembled[:, 3, 3] = 1.0
    assembled[:, :3, 3] = T
    assembled[:, :3, :3] = R

    # ... and transpose every matrix so that T ends up in the last row.
    return assembled.permute(0, 2, 1)
def se3_log_map(
    transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
) -> torch.Tensor:
    """
    Convert a batch of 4x4 transformation matrices `transform`
    to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
    ```
    [ R 0 ]
    [ T 1 ] ,
    ```
    where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D
    translation vector. SE(3) matrices are commonly used to represent rigid
    motions or camera extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and
    `log_rotation`.

    The conversion from the 4x4 SE(3) matrix `transform` to the
    6D representation `log_transform = [log_translation | log_rotation]`
    is done as follows:
    ```
    log_transform = log(transform)
    log_translation = log_transform[3, :3]
    log_rotation = inv_hat(log_transform[:3, :3])
    ```
    where `log` is the matrix logarithm
    and `inv_hat` is the inverse of the Hat operator [2].

    Note that for any valid 4x4 `transform` matrix, the following identity
    holds:
    ```
    se3_exp_map(se3_log_map(transform)) == transform
    ```

    The conversion has a singularity around `(transform=I)` which is handled
    by clamping controlled with the `eps` and `cos_bound` arguments.

    Args:
        transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
        eps: A threshold for clipping the squared norm of the rotation
            logarithm to avoid division by zero in the singular case.
            It is passed both to `so3_log_map` and to the computation of the
            rotation angles that the `V` matrix divides by.
        cos_bound: Passed unchanged to `so3_log_map`, where it bounds the
            cosine of the rotation angle away from the ends of its range to
            avoid non-finite outputs of the `acos` call.

    Returns:
        Batch of logarithms of input SE(3) matrices
        of shape `(minibatch, 6)`.

    Raises:
        ValueError if `transform` is of incorrect shape.
        ValueError if `transform[:, :3, 3]` is not (close to) zero.
        ValueError if `R` has an unexpected trace (no such check is done in
            this function; presumably raised inside `so3_log_map`).

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """
    if transform.ndim != 3 or transform.shape[1:] != (4, 4):
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    # In the row-vector layout the last column must be [0, 0, 0, 1]^T.
    last_column_top = transform[:, :3, 3]
    if not torch.allclose(last_column_top, torch.zeros_like(last_column_top)):
        raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")

    # log_rotation is so3_log_map of the transposed upper-left 3x3 block.
    R = transform[:, :3, :3].permute(0, 2, 1)
    log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)

    # log_translation is V^-1 @ T. `eps` has to reach `_get_se3_V_input`,
    # because that is where the squared norm is clamped; `_se3_V_matrix`
    # itself never reads its `eps` argument.
    T = transform[:, 3, :3]
    V = _se3_V_matrix(*_get_se3_V_input(log_rotation, eps=eps), eps=eps)
    log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]

    return torch.cat((log_translation, log_rotation), dim=1)
def _se3_V_matrix(
    log_rotation: torch.Tensor,
    log_rotation_hat: torch.Tensor,
    log_rotation_hat_square: torch.Tensor,
    rotation_angles: torch.Tensor,
    eps: float = 1e-4,
) -> torch.Tensor:
    """
    Build the "V" matrix from [1], Sec 9.4.2.:
    ```
    V = I + (1 - cos(a)) / a^2 * hat + (a - sin(a)) / a^3 * hat^2 ,
    ```
    where `a` stands for `rotation_angles`, `hat` for `log_rotation_hat`
    and `hat^2` for `log_rotation_hat_square`.

    `log_rotation` only supplies the dtype and device of the identity
    matrix. `eps` is accepted but not used by this function, so the caller
    has to pass `rotation_angles` that are safe to divide by.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    """
    identity = torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]

    # Per-element scalar weights of the hat and hat^2 terms.
    hat_weight = (1 - torch.cos(rotation_angles)) / (rotation_angles ** 2)
    hat_square_weight = (rotation_angles - torch.sin(rotation_angles)) / (
        rotation_angles ** 3
    )

    hat_term = log_rotation_hat * hat_weight[:, None, None]
    hat_square_term = log_rotation_hat_square * hat_square_weight[:, None, None]

    return identity + hat_term + hat_square_term
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
    """
    Prepare the positional arguments of `_se3_V_matrix` from a batch of
    rotation logarithms.

    Returns the 4-tuple
    `(log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles)`,
    where `rotation_angles` is the square root of the squared norm of each
    `log_rotation` vector after clamping it from below with `eps`.
    """
    hat_matrix = hat(log_rotation)
    hat_matrix_square = torch.bmm(hat_matrix, hat_matrix)

    # Clamp the squared norm (not the norm) so the angle never drops
    # below sqrt(eps).
    squared_norm = (log_rotation ** 2).sum(-1)
    angles = torch.clamp(squared_norm, min=eps).sqrt()

    return log_rotation, hat_matrix, hat_matrix_square, angles

View File

@ -14,6 +14,7 @@ def so3_relative_angle(
R2: torch.Tensor,
cos_angle: bool = False,
cos_bound: float = 1e-4,
eps: float = 1e-4,
) -> torch.Tensor:
"""
Calculates the relative angle (in radians) between pairs of
@ -33,7 +34,8 @@ def so3_relative_angle(
of the `acos` call. Note that the non-finite outputs/gradients
are returned when the angle is requested (i.e. `cos_angle==False`)
and the rotation angle is close to 0 or π.
eps: Tolerance for the valid trace check of the relative rotation matrix
in `so3_rotation_angle`.
Returns:
Corresponding rotation angles of shape `(minibatch,)`.
If `cos_angle==True`, returns the cosine of the angles.
@ -43,7 +45,7 @@ def so3_relative_angle(
ValueError if `R1` or `R2` has an unexpected trace.
"""
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps)
def so3_rotation_angle(

19
tests/bm_se3.py Normal file
View File

@ -0,0 +1,19 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
from fvcore.common.benchmark import benchmark
from test_se3 import TestSE3
def bm_se3() -> None:
    """
    Benchmark the SE(3) exponential and logarithm maps for batch sizes
    1, 10, 100 and 1000, using the closure factories defined on `TestSE3`.
    """
    kwargs_list = [{"batch_size": size} for size in (1, 10, 100, 1000)]
    cases = (
        (TestSE3.se3_expmap, "SE3_EXP"),
        (TestSE3.se3_logmap, "SE3_LOG"),
    )
    for closure_factory, label in cases:
        benchmark(closure_factory, label, kwargs_list, warmup_iters=1)


if __name__ == "__main__":
    bm_se3()

324
tests/test_se3.py Normal file
View File

@ -0,0 +1,324 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import unittest
import numpy as np
import torch
from common_testing import TestCaseMixin
from pytorch3d.transforms.rotation_conversions import random_rotations
from pytorch3d.transforms.se3 import se3_exp_map, se3_log_map
from pytorch3d.transforms.so3 import (
so3_exp_map,
so3_log_map,
so3_rotation_angle,
)
class TestSE3(TestCaseMixin, unittest.TestCase):
    """
    Unit tests and benchmark closures for `se3_exp_map` and `se3_log_map`.

    All randomly generated inputs are created on the `cuda:0` device.
    """

    # 6D logarithms `[log_translation | log_rotation]`; row `i` corresponds
    # to `precomputed_transform[i]` (see `test_compare_with_precomputed`).
    precomputed_log_transform = torch.tensor(
        [
            [0.1900, 2.1600, -0.1700, 0.8500, -1.9200, 0.6500],
            [-0.6500, -0.8200, 0.5300, -1.2800, -1.6600, -0.3000],
            [-0.0900, 0.2000, -1.1200, 1.8600, -0.7100, 0.6900],
            [0.8000, -0.0300, 1.4900, -0.5200, -0.2500, 1.4700],
            [-0.3300, -1.1600, 2.3600, -0.6900, 0.1800, -1.1800],
            [-1.8000, -1.5800, 0.8400, 1.4200, 0.6500, 0.4300],
            [-1.5900, 0.6200, 1.6900, -0.6600, 0.9400, 0.0800],
            [0.0800, -0.1400, 0.3300, -0.5900, -1.0700, 0.1000],
            [-0.3300, -0.5300, -0.8800, 0.3900, 0.1600, -0.2000],
            [1.0100, -1.3500, -0.3500, -0.6400, 0.4500, -0.5400],
        ],
        dtype=torch.float32,
    )

    # 4x4 SE(3) matrices with the translation stored in the last row.
    precomputed_transform = torch.tensor(
        [
            [
                [-0.3496, -0.2966, 0.8887, 0.0000],
                [-0.7755, 0.6239, -0.0968, 0.0000],
                [-0.5258, -0.7230, -0.4481, 0.0000],
                [-0.7392, 1.9119, 0.3122, 1.0000],
            ],
            [
                [0.0354, 0.5992, 0.7998, 0.0000],
                [0.8413, 0.4141, -0.3475, 0.0000],
                [-0.5395, 0.6852, -0.4894, 0.0000],
                [-0.9902, -0.4840, 0.1226, 1.0000],
            ],
            [
                [0.6664, -0.1679, 0.7264, 0.0000],
                [-0.7309, -0.3394, 0.5921, 0.0000],
                [0.1471, -0.9255, -0.3489, 0.0000],
                [-0.0815, 0.8719, -0.4516, 1.0000],
            ],
            [
                [0.1010, 0.9834, -0.1508, 0.0000],
                [-0.8783, 0.0169, -0.4779, 0.0000],
                [-0.4674, 0.1807, 0.8654, 0.0000],
                [0.2375, 0.7043, 1.4159, 1.0000],
            ],
            [
                [0.3935, -0.8930, 0.2184, 0.0000],
                [0.7873, 0.2047, -0.5817, 0.0000],
                [0.4747, 0.4009, 0.7836, 0.0000],
                [-0.3476, -0.0424, 2.5408, 1.0000],
            ],
            [
                [0.7572, 0.6342, -0.1567, 0.0000],
                [0.1039, 0.1199, 0.9873, 0.0000],
                [0.6449, -0.7638, 0.0249, 0.0000],
                [-1.2885, -2.0666, -0.1137, 1.0000],
            ],
            [
                [0.6020, -0.2140, -0.7693, 0.0000],
                [-0.3409, 0.8024, -0.4899, 0.0000],
                [0.7221, 0.5572, 0.4101, 0.0000],
                [-0.7550, 1.1928, 1.8480, 1.0000],
            ],
            [
                [0.4913, 0.3548, 0.7954, 0.0000],
                [0.2013, 0.8423, -0.5000, 0.0000],
                [-0.8474, 0.4058, 0.3424, 0.0000],
                [-0.1003, -0.0406, 0.3295, 1.0000],
            ],
            [
                [0.9678, -0.1622, -0.1926, 0.0000],
                [0.2235, 0.9057, 0.3603, 0.0000],
                [0.1160, -0.3917, 0.9128, 0.0000],
                [-0.4417, -0.3111, -0.9227, 1.0000],
            ],
            [
                [0.7710, -0.5957, -0.2250, 0.0000],
                [0.3288, 0.6750, -0.6605, 0.0000],
                [0.5454, 0.4352, 0.7163, 0.0000],
                [0.5623, -1.5886, -0.0182, 1.0000],
            ],
        ],
        dtype=torch.float32,
    )

    def setUp(self) -> None:
        """Seed the torch and numpy random generators before every test."""
        super().setUp()
        torch.manual_seed(42)
        np.random.seed(42)

    @staticmethod
    def init_log_transform(batch_size: int = 10):
        """
        Initialize a list of `batch_size` 6-dimensional vectors representing
        randomly generated logarithms of SE(3) transforms.
        """
        device = torch.device("cuda:0")
        log_transform = torch.randn(
            (batch_size, 6), dtype=torch.float32, device=device
        )
        return log_transform

    @staticmethod
    def init_transform(batch_size: int = 10):
        """
        Initialize a list of `batch_size` 4x4 SE(3) transforms.
        """
        device = torch.device("cuda:0")
        transform = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device)
        transform[:, :3, :3] = random_rotations(
            batch_size, dtype=torch.float32, device=device
        )
        transform[:, 3, :3] = torch.randn(
            (batch_size, 3), dtype=torch.float32, device=device
        )
        transform[:, 3, 3] = 1.0
        return transform

    def test_se3_exp_output_format(self, batch_size: int = 100):
        """
        Check that the output of `se3_exp_map` is a valid SE3 matrix.
        """
        transform = se3_exp_map(TestSE3.init_log_transform(batch_size=batch_size))
        R = transform[:, :3, :3]
        T = transform[:, 3, :3]
        rest = transform[:, :, 3]
        Rdet = R.det()

        # check det(R)==1
        self.assertClose(Rdet, torch.ones_like(Rdet), atol=1e-4)

        # check that the translation is a finite vector
        self.assertTrue(torch.isfinite(T).all())

        # check last column == [0,0,0,1]
        last_col = rest.new_zeros(batch_size, 4)
        last_col[:, -1] = 1.0
        self.assertClose(rest, last_col)

    def test_compare_with_precomputed(self):
        """
        Compare the outputs against precomputed results.
        """
        self.assertClose(
            se3_log_map(self.precomputed_transform),
            self.precomputed_log_transform,
            atol=1e-4,
        )
        self.assertClose(
            self.precomputed_transform,
            se3_exp_map(self.precomputed_log_transform),
            atol=1e-4,
        )

    def test_se3_exp_singularity(self, batch_size: int = 100):
        """
        Tests whether the `se3_exp_map` is robust to the input vectors
        with low L2 norms, where the algorithm is numerically unstable.
        """
        # generate random log-transforms with a tiny norm
        log_transform = TestSE3.init_log_transform(batch_size=batch_size)
        log_transform_small = log_transform * 1e-6
        log_transform_small.requires_grad = True
        transforms = se3_exp_map(log_transform_small)

        # tests whether all outputs are finite
        self.assertTrue(torch.isfinite(transforms).all())

        # tests whether all gradients are finite and not None
        loss = transforms.sum()
        loss.backward()
        self.assertIsNotNone(log_transform_small.grad)
        self.assertTrue(torch.isfinite(log_transform_small.grad).all())

    def test_se3_log_singularity(self, batch_size: int = 100):
        """
        Tests whether the `se3_log_map` is robust to the input matrices
        whose rotation angles and translations are close to the numerically
        unstable region (i.e. matrices with low rotation angles
        and 0 translation).
        """
        # generate random rotations with a tiny angle
        device = torch.device("cuda:0")
        identity = torch.eye(3, device=device)
        rot180 = identity * torch.tensor([[1.0, -1.0, -1.0]], device=device)
        r = [identity, rot180]
        r.extend(
            [
                torch.qr(identity + torch.randn_like(identity) * 1e-6)[0]
                + float(i > batch_size // 2) * (0.5 - torch.rand_like(identity)) * 1e-8
                # this adds random noise to the second half
                # of the random orthogonal matrices to generate
                # near-orthogonal matrices
                for i in range(batch_size - 2)
            ]
        )
        r = torch.stack(r)

        # tiny translations
        t = torch.randn(batch_size, 3, dtype=r.dtype, device=device) * 1e-6

        # create the transform matrix
        transform = torch.zeros(batch_size, 4, 4, dtype=torch.float32, device=device)
        transform[:, :3, :3] = r
        transform[:, 3, :3] = t
        transform[:, 3, 3] = 1.0
        transform.requires_grad = True

        # the log of the transform
        log_transform = se3_log_map(transform, eps=1e-4, cos_bound=1e-4)

        # tests whether all outputs are finite
        self.assertTrue(torch.isfinite(log_transform).all())

        # tests whether all gradients are finite and not None
        loss = log_transform.sum()
        loss.backward()
        self.assertIsNotNone(transform.grad)
        self.assertTrue(torch.isfinite(transform.grad).all())

    def test_se3_exp_zero_translation(self, batch_size: int = 100):
        """
        Check that `se3_exp_map` with zero translation gives
        the same result as corresponding `so3_exp_map`.
        """
        log_transform = TestSE3.init_log_transform(batch_size=batch_size)
        log_transform[:, :3] *= 0.0
        transform = se3_exp_map(log_transform, eps=1e-8)
        transform_so3 = so3_exp_map(log_transform[:, 3:], eps=1e-8)

        # the rotation block is the transpose of the so3 result
        self.assertClose(
            transform[:, :3, :3], transform_so3.permute(0, 2, 1), atol=1e-4
        )

        # the translation row has to stay zero; build the reference from the
        # very slice that is being checked
        translation = transform[:, 3, :3]
        self.assertClose(translation, torch.zeros_like(translation), atol=1e-4)

    def test_se3_log_zero_translation(self, batch_size: int = 100):
        """
        Check that `se3_log_map` with zero translation gives
        the same result as corresponding `so3_log_map`.
        """
        transform = TestSE3.init_transform(batch_size=batch_size)
        transform[:, 3, :3] *= 0.0
        log_transform = se3_log_map(transform, eps=1e-8, cos_bound=1e-4)
        log_transform_so3 = so3_log_map(transform[:, :3, :3], eps=1e-8, cos_bound=1e-4)

        # `se3_log_map` transposes the rotation block before taking its
        # logarithm, hence the flipped sign of the so3 reference
        self.assertClose(log_transform[:, 3:], -log_transform_so3, atol=1e-4)
        self.assertClose(
            log_transform[:, :3], torch.zeros_like(log_transform[:, :3]), atol=1e-4
        )

    def test_se3_exp_to_log_to_exp(self, batch_size: int = 10000):
        """
        Check that `se3_exp_map(se3_log_map(A))==A` for
        a batch of randomly generated SE(3) matrices `A`.
        """
        transform = TestSE3.init_transform(batch_size=batch_size)

        # Limit test transforms to those not around the singularity where
        # the rotation angle~=pi.
        nonsingular = so3_rotation_angle(transform[:, :3, :3]) < 3.134
        transform = transform[nonsingular]
        transform_ = se3_exp_map(
            se3_log_map(transform, eps=1e-8, cos_bound=0.0), eps=1e-8
        )
        self.assertClose(transform, transform_, atol=0.02)

    def test_se3_log_to_exp_to_log(self, batch_size: int = 100):
        """
        Check that `se3_log_map(se3_exp_map(log_transform))==log_transform`
        for a randomly generated batch of SE(3) matrix logarithms `log_transform`.
        """
        log_transform = TestSE3.init_log_transform(batch_size=batch_size)
        log_transform_ = se3_log_map(se3_exp_map(log_transform, eps=1e-8), eps=1e-8)
        self.assertClose(log_transform, log_transform_, atol=1e-1)

    def test_bad_se3_input_value_err(self):
        """
        Tests whether `se3_exp_map` and `se3_log_map` correctly raise
        a ValueError if called with an argument of incorrect shape, or with
        a tensor containing illegal values.
        """
        device = torch.device("cuda:0")

        for size in ([5, 4], [3, 4, 5], [3, 5, 6]):
            log_transform = torch.randn(size=size, device=device)
            with self.assertRaises(ValueError):
                se3_exp_map(log_transform)

        for size in ([5, 4], [3, 4, 5], [3, 5, 6], [2, 2, 3, 4]):
            transform = torch.randn(size=size, device=device)
            with self.assertRaises(ValueError):
                se3_log_map(transform)

        # Test the case where transform[:, :3, 3] != 0.
        # (every entry of this tensor is >= 0.1)
        transform = torch.rand(size=[5, 4, 4], device=device) + 0.1
        with self.assertRaises(ValueError):
            se3_log_map(transform)

    @staticmethod
    def se3_expmap(batch_size: int = 10):
        """
        Return a closure that runs `se3_exp_map` on a fixed random batch and
        synchronizes CUDA afterwards; used by the benchmark.
        """
        log_transform = TestSE3.init_log_transform(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_transforms():
            se3_exp_map(log_transform)
            torch.cuda.synchronize()

        return compute_transforms

    @staticmethod
    def se3_logmap(batch_size: int = 10):
        """
        Return a closure that runs `se3_log_map` on a fixed random batch of
        transforms and synchronizes CUDA afterwards; used by the benchmark.
        """
        transform = TestSE3.init_transform(batch_size=batch_size)
        torch.cuda.synchronize()

        def compute_logs():
            se3_log_map(transform)
            torch.cuda.synchronize()

        return compute_logs