mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
SE3 exponential and logarithm maps.
Summary: Implements the SE3 logarithm and exponential maps. (this is a second part of the split of D23326429) Outputs of `bm_se3`: ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_EXP_1 738 885 678 SE3_EXP_10 717 877 698 SE3_EXP_100 718 847 697 SE3_EXP_1000 729 1181 686 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_LOG_1 1451 2267 345 SE3_LOG_10 2185 2453 229 SE3_LOG_100 2217 2448 226 SE3_LOG_1000 2455 2599 204 -------------------------------------------------------------------------------- ``` Reviewed By: patricklabatut Differential Revision: D27852557 fbshipit-source-id: e42ccc9cfffe780e9cad24129de15624ae818472
This commit is contained in:
committed by
Facebook GitHub Bot
parent
9f14e82b5a
commit
b2ac2655b3
@@ -20,6 +20,7 @@ from .rotation_conversions import (
|
||||
rotation_6d_to_matrix,
|
||||
standardize_quaternion,
|
||||
)
|
||||
from .se3 import se3_exp_map, se3_log_map
|
||||
from .so3 import (
|
||||
so3_exponential_map,
|
||||
so3_exp_map,
|
||||
|
||||
213
pytorch3d/transforms/se3.py
Normal file
213
pytorch3d/transforms/se3.py
Normal file
@@ -0,0 +1,213 @@
|
||||
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
import torch
|
||||
|
||||
from .so3 import hat, _so3_exp_map, so3_log_map
|
||||
|
||||
|
||||
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of logarithmic representations of SE(3) matrices `log_transform`
|
||||
to a batch of 4x4 SE(3) matrices using the exponential map.
|
||||
See e.g. [1], Sec 9.4.2. for more detailed description.
|
||||
|
||||
A SE(3) matrix has the following form:
|
||||
```
|
||||
[ R 0 ]
|
||||
[ T 1 ] ,
|
||||
```
|
||||
where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector.
|
||||
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
|
||||
|
||||
In the SE(3) logarithmic representation SE(3) matrices are
|
||||
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
|
||||
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
|
||||
|
||||
The conversion from the 6D representation to a 4x4 SE(3) matrix `transform`
|
||||
is done as follows:
|
||||
```
|
||||
transform = exp( [ hat(log_rotation) 0 ]
|
||||
[ log_translation 1 ] ) ,
|
||||
```
|
||||
where `exp` is the matrix exponential and `hat` is the Hat operator [2].
|
||||
|
||||
Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi`
|
||||
(i.e. the rotation angle is between 0 and 2pi), the following identity holds:
|
||||
```
|
||||
se3_log_map(se3_exponential_map(log_transform)) == log_transform
|
||||
```
|
||||
|
||||
The conversion has a singularity around `||log(transform)|| = 0`
|
||||
which is handled by clamping controlled with the `eps` argument.
|
||||
|
||||
Args:
|
||||
log_transform: Batch of vectors of shape `(minibatch, 6)`.
|
||||
eps: A threshold for clipping the squared norm of the rotation logarithm
|
||||
to avoid unstable gradients in the singular case.
|
||||
|
||||
Returns:
|
||||
Batch of transformation matrices of shape `(minibatch, 4, 4)`.
|
||||
|
||||
Raises:
|
||||
ValueError if `log_transform` is of incorrect shape.
|
||||
|
||||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
|
||||
[2] https://en.wikipedia.org/wiki/Hat_operator
|
||||
"""
|
||||
|
||||
if log_transform.ndim != 2 or log_transform.shape[1] != 6:
|
||||
raise ValueError("Expected input to be of shape (N, 6).")
|
||||
|
||||
N, _ = log_transform.shape
|
||||
|
||||
log_translation = log_transform[..., :3]
|
||||
log_rotation = log_transform[..., 3:]
|
||||
|
||||
# rotation is an exponential map of log_rotation
|
||||
(
|
||||
R,
|
||||
rotation_angles,
|
||||
log_rotation_hat,
|
||||
log_rotation_hat_square,
|
||||
) = _so3_exp_map(log_rotation, eps=eps)
|
||||
|
||||
# translation is V @ T
|
||||
V = _se3_V_matrix(
|
||||
log_rotation,
|
||||
log_rotation_hat,
|
||||
log_rotation_hat_square,
|
||||
rotation_angles,
|
||||
eps=eps,
|
||||
)
|
||||
T = torch.bmm(V, log_translation[:, :, None])[:, :, 0]
|
||||
|
||||
transform = torch.zeros(
|
||||
N, 4, 4, dtype=log_transform.dtype, device=log_transform.device
|
||||
)
|
||||
|
||||
transform[:, :3, :3] = R
|
||||
transform[:, :3, 3] = T
|
||||
transform[:, 3, 3] = 1.0
|
||||
|
||||
return transform.permute(0, 2, 1)
|
||||
|
||||
|
||||
def se3_log_map(
|
||||
transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a batch of 4x4 transformation matrices `transform`
|
||||
to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
|
||||
See e.g. [1], Sec 9.4.2. for more detailed description.
|
||||
|
||||
A SE(3) matrix has the following form:
|
||||
```
|
||||
[ R 0 ]
|
||||
[ T 1 ] ,
|
||||
```
|
||||
where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector.
|
||||
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
|
||||
|
||||
In the SE(3) logarithmic representation SE(3) matrices are
|
||||
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
|
||||
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
|
||||
|
||||
The conversion from the 4x4 SE(3) matrix `transform` to the
|
||||
6D representation `log_transform = [log_translation | log_rotation]`
|
||||
is done as follows:
|
||||
```
|
||||
log_transform = log(transform)
|
||||
log_translation = log_transform[3, :3]
|
||||
log_rotation = inv_hat(log_transform[:3, :3])
|
||||
```
|
||||
where `log` is the matrix logarithm
|
||||
and `inv_hat` is the inverse of the Hat operator [2].
|
||||
|
||||
Note that for any valid 4x4 `transform` matrix, the following identity holds:
|
||||
```
|
||||
se3_exp_map(se3_log_map(transform)) == transform
|
||||
```
|
||||
|
||||
The conversion has a singularity around `(transform=I)` which is handled
|
||||
by clamping controlled with the `eps` and `cos_bound` arguments.
|
||||
|
||||
Args:
|
||||
transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
|
||||
eps: A threshold for clipping the squared norm of the rotation logarithm
|
||||
to avoid division by zero in the singular case.
|
||||
cos_bound: Clamps the cosine of the rotation angle to
|
||||
[-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
|
||||
The non-finite outputs can be caused by passing small rotation angles
|
||||
to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
|
||||
|
||||
Returns:
|
||||
Batch of logarithms of input SE(3) matrices
|
||||
of shape `(minibatch, 6)`.
|
||||
|
||||
Raises:
|
||||
ValueError if `transform` is of incorrect shape.
|
||||
ValueError if `R` has an unexpected trace.
|
||||
|
||||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
|
||||
[2] https://en.wikipedia.org/wiki/Hat_operator
|
||||
"""
|
||||
|
||||
if transform.ndim != 3:
|
||||
raise ValueError("Input tensor shape has to be (N, 4, 4).")
|
||||
|
||||
N, dim1, dim2 = transform.shape
|
||||
if dim1 != 4 or dim2 != 4:
|
||||
raise ValueError("Input tensor shape has to be (N, 4, 4).")
|
||||
|
||||
if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])):
|
||||
raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")
|
||||
|
||||
# log_rot is just so3_log_map of the upper left 3x3 block
|
||||
R = transform[:, :3, :3].permute(0, 2, 1)
|
||||
log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)
|
||||
|
||||
# log_translation is V^-1 @ T
|
||||
T = transform[:, 3, :3]
|
||||
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
|
||||
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
|
||||
|
||||
return torch.cat((log_translation, log_rotation), dim=1)
|
||||
|
||||
|
||||
def _se3_V_matrix(
|
||||
log_rotation: torch.Tensor,
|
||||
log_rotation_hat: torch.Tensor,
|
||||
log_rotation_hat_square: torch.Tensor,
|
||||
rotation_angles: torch.Tensor,
|
||||
eps: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
A helper function that computes the "V" matrix from [1], Sec 9.4.2.
|
||||
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
|
||||
"""
|
||||
|
||||
V = (
|
||||
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
||||
+ log_rotation_hat
|
||||
* ((1 - torch.cos(rotation_angles)) / (rotation_angles ** 2))[:, None, None]
|
||||
+ (
|
||||
log_rotation_hat_square
|
||||
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles ** 3))[
|
||||
:, None, None
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
return V
|
||||
|
||||
|
||||
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
|
||||
"""
|
||||
A helper function that computes the input variables to the `_se3_V_matrix`
|
||||
function.
|
||||
"""
|
||||
nrms = (log_rotation ** 2).sum(-1)
|
||||
rotation_angles = torch.clamp(nrms, eps).sqrt()
|
||||
log_rotation_hat = hat(log_rotation)
|
||||
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
|
||||
return log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles
|
||||
@@ -14,6 +14,7 @@ def so3_relative_angle(
|
||||
R2: torch.Tensor,
|
||||
cos_angle: bool = False,
|
||||
cos_bound: float = 1e-4,
|
||||
eps: float = 1e-4,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Calculates the relative angle (in radians) between pairs of
|
||||
@@ -33,7 +34,8 @@ def so3_relative_angle(
|
||||
of the `acos` call. Note that the non-finite outputs/gradients
|
||||
are returned when the angle is requested (i.e. `cos_angle==False`)
|
||||
and the rotation angle is close to 0 or π.
|
||||
|
||||
eps: Tolerance for the valid trace check of the relative rotation matrix
|
||||
in `so3_rotation_angle`.
|
||||
Returns:
|
||||
Corresponding rotation angles of shape `(minibatch,)`.
|
||||
If `cos_angle==True`, returns the cosine of the angles.
|
||||
@@ -43,7 +45,7 @@ def so3_relative_angle(
|
||||
ValueError if `R1` or `R2` has an unexpected trace.
|
||||
"""
|
||||
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
|
||||
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
|
||||
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps)
|
||||
|
||||
|
||||
def so3_rotation_angle(
|
||||
|
||||
Reference in New Issue
Block a user