mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
Summary: Remove the compat functions eigh, solve, lstsq, and qr. Migrate callers to use torch.linalg directly. Reviewed By: bottler Differential Revision: D39172949 fbshipit-source-id: 484230a553237808f06ee5cdfde64651cba91c4c
222 lines
7.8 KiB
Python
222 lines
7.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
import torch
|
|
|
|
from .so3 import _so3_exp_map, hat, so3_log_map
|
|
|
|
|
|
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
    """
    Map a batch of 6D SE(3) logarithms `log_transform` to a batch of
    4x4 SE(3) matrices with the exponential map.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
        ```
        [ R 0 ]
        [ T 1 ] ,
        ```
    where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector.
    SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.

    The 4x4 matrix is assembled as follows:
        ```
        R = so3 exponential map of log_rotation
        T = V @ log_translation
        ```
    where `V` is the matrix computed by `_se3_V_matrix` ([1], Sec 9.4.2).
    `R` and `T` are written into the upper-left 3x3 block and the last column
    of a 4x4 matrix whose two trailing dimensions are then swapped, which
    yields the layout shown above. This corresponds to the matrix exponential
    of the se(3) element formed from `hat(log_rotation)` and `log_translation`,
    where `hat` is the Hat operator [2].

    Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi`
    (i.e. the rotation angle is between 0 and 2pi), the following identity holds:
        ```
        se3_log_map(se3_exp_map(log_transform)) == log_transform
        ```

    The conversion has a singularity around `||log(transform)|| = 0`
    which is handled by clamping controlled with the `eps` argument.

    Args:
        log_transform: Batch of vectors of shape `(minibatch, 6)`.
        eps: A threshold for clipping the squared norm of the rotation logarithm
            to avoid unstable gradients in the singular case.

    Returns:
        Batch of transformation matrices of shape `(minibatch, 4, 4)`.

    Raises:
        ValueError if `log_transform` is of incorrect shape.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """

    is_batch_of_6d = log_transform.ndim == 2 and log_transform.shape[1] == 6
    if not is_batch_of_6d:
        raise ValueError("Expected input to be of shape (N, 6).")

    batch_size = log_transform.shape[0]

    # The first three entries hold the translation part, the last three the
    # rotation part of the logarithm.
    log_translation, log_rotation = log_transform[..., :3], log_transform[..., 3:]

    # The rotation is the SO(3) exponential map of log_rotation; the helper also
    # returns the intermediate quantities needed to build V.
    so3_outputs = _so3_exp_map(log_rotation, eps=eps)
    R, rotation_angles, log_rotation_hat, log_rotation_hat_square = so3_outputs

    # The translation is V @ log_translation.
    V = _se3_V_matrix(
        log_rotation,
        log_rotation_hat,
        log_rotation_hat_square,
        rotation_angles,
        eps=eps,
    )
    T = torch.bmm(V, log_translation[:, :, None])[:, :, 0]

    # Fill a zero-initialised (N, 4, 4) buffer with matching dtype/device.
    transform = log_transform.new_zeros(batch_size, 4, 4)
    transform[:, 3, 3] = 1.0
    transform[:, :3, 3] = T
    transform[:, :3, :3] = R

    # Swap the two matrix dimensions so that T ends up in the bottom row.
    return transform.permute(0, 2, 1)
|
|
|
|
|
|
def se3_log_map(
    transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
) -> torch.Tensor:
    """
    Convert a batch of 4x4 transformation matrices `transform`
    to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
    See e.g. [1], Sec 9.4.2. for more detailed description.

    A SE(3) matrix has the following form:
        ```
        [ R 0 ]
        [ T 1 ] ,
        ```
    where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector.
    SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.

    In the SE(3) logarithmic representation SE(3) matrices are
    represented as 6-dimensional vectors `[log_translation | log_rotation]`,
    i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.

    The conversion from the 4x4 SE(3) matrix `transform` to the
    6D representation `log_transform = [log_translation | log_rotation]`
    is done as follows:
        ```
        log_transform = log(transform)
        log_translation = log_transform[3, :3]
        log_rotation = inv_hat(log_transform[:3, :3])
        ```
    where `log` is the matrix logarithm
    and `inv_hat` is the inverse of the Hat operator [2].

    Note that for any valid 4x4 `transform` matrix, the following identity holds:
        ```
        se3_exp_map(se3_log_map(transform)) == transform
        ```

    The conversion has a singularity around `(transform=I)` which is handled
    by clamping controlled with the `eps` and `cos_bound` arguments.

    Args:
        transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
        eps: A threshold for clipping the squared norm of the rotation logarithm
            to avoid division by zero in the singular case. The same value is
            used for the rotation logarithm and for the rotation-angle clamp
            applied when building the `V` matrix.
        cos_bound: Clamps the cosine of the rotation angle to
            [-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
            The non-finite outputs can be caused by passing small rotation angles
            to the `acos` function in `so3_rotation_angle` of `so3_log_map`.

    Returns:
        Batch of logarithms of input SE(3) matrices
        of shape `(minibatch, 6)`.

    Raises:
        ValueError if `transform` is of incorrect shape.
        ValueError if `R` has an unexpected trace.

    [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
    [2] https://en.wikipedia.org/wiki/Hat_operator
    """

    if transform.ndim != 3:
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    _, dim1, dim2 = transform.shape
    if dim1 != 4 or dim2 != 4:
        raise ValueError("Input tensor shape has to be (N, 4, 4).")

    # In the row-vector layout the last column of the top three rows must vanish.
    last_column = transform[:, :3, 3]
    if not torch.allclose(last_column, torch.zeros_like(last_column)):
        raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")

    # log_rotation is the so3_log_map of the transposed upper-left 3x3 block.
    R = transform[:, :3, :3].permute(0, 2, 1)
    log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)

    # log_translation is V^-1 @ T.
    T = transform[:, 3, :3]
    # Forward `eps` to `_get_se3_V_input` as well: that helper is where the
    # rotation angle is clamped, so omitting it would silently fall back to the
    # helper's default threshold whenever the caller passes a custom `eps`.
    V = _se3_V_matrix(*_get_se3_V_input(log_rotation, eps=eps), eps=eps)
    log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]

    return torch.cat((log_translation, log_rotation), dim=1)
|
|
|
|
|
|
def _se3_V_matrix(
|
|
log_rotation: torch.Tensor,
|
|
log_rotation_hat: torch.Tensor,
|
|
log_rotation_hat_square: torch.Tensor,
|
|
rotation_angles: torch.Tensor,
|
|
eps: float = 1e-4,
|
|
) -> torch.Tensor:
|
|
"""
|
|
A helper function that computes the "V" matrix from [1], Sec 9.4.2.
|
|
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
|
|
"""
|
|
|
|
V = (
|
|
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
|
|
+ log_rotation_hat
|
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
|
|
* ((1 - torch.cos(rotation_angles)) / (rotation_angles**2))[:, None, None]
|
|
+ (
|
|
log_rotation_hat_square
|
|
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
|
# `int`.
|
|
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles**3))[
|
|
:, None, None
|
|
]
|
|
)
|
|
)
|
|
|
|
return V
|
|
|
|
|
|
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
    """
    A helper function that computes the input variables to the `_se3_V_matrix`
    function.

    Args:
        log_rotation: Batch of rotation logarithms; the squared norm is taken
            over the last dimension and `hat` / `torch.bmm` are applied to it.
        eps: Lower bound applied to the squared norm before the square root,
            so the returned angles are never below `sqrt(eps)`.

    Returns:
        A 4-tuple `(log_rotation, log_rotation_hat, log_rotation_hat_square,
        rotation_angles)` in the positional order expected by `_se3_V_matrix`.
    """
    # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
    squared_norms = (log_rotation**2).sum(-1)
    # Clamp before the square root to keep the angle away from zero.
    rotation_angles = torch.clamp(squared_norms, eps).sqrt()

    skew = hat(log_rotation)
    skew_squared = torch.bmm(skew, skew)

    return log_rotation, skew, skew_squared, rotation_angles
|