mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
Extending the API of Transform3d with SE(3) log
Summary: This is quite a thin wrapper – not sure we need it. The motivation is that `Transform3d` is not as matrix-centric now, it can be converted to SE(3) logarithm equally easily. It simplifies things like averaging cameras and getting axis-angle of camera rotation (previously, one would need to call `se3_log_map(cameras.get_world_to_camera_transform().get_matrix())`), now one fewer thing to call / discover. Reviewed By: bottler Differential Revision: D39928000 fbshipit-source-id: 85248d5b8af136618f1d08791af5297ea5179d19
This commit is contained in:
parent
74bbd6fd76
commit
9a0f9ae572
@ -13,6 +13,7 @@ import torch
|
|||||||
from ..common.datatypes import Device, get_device, make_device
|
from ..common.datatypes import Device, get_device, make_device
|
||||||
from ..common.workaround import _safe_det_3x3
|
from ..common.workaround import _safe_det_3x3
|
||||||
from .rotation_conversions import _axis_angle_rotation
|
from .rotation_conversions import _axis_angle_rotation
|
||||||
|
from .se3 import se3_log_map
|
||||||
|
|
||||||
|
|
||||||
class Transform3d:
|
class Transform3d:
|
||||||
@ -130,13 +131,13 @@ class Transform3d:
|
|||||||
[Tx, Ty, Tz, 1],
|
[Tx, Ty, Tz, 1],
|
||||||
]
|
]
|
||||||
|
|
||||||
To apply the transformation to points which are row vectors, the M matrix
|
To apply the transformation to points, which are row vectors, the latter are
|
||||||
can be pre multiplied by the points:
|
converted to homogeneous (4D) coordinates and right-multiplied by the M matrix:
|
||||||
|
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
|
|
||||||
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
||||||
transformed_points = points * M
|
[transformed_points, 1] ∝ [points, 1] @ M
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -218,9 +219,10 @@ class Transform3d:
|
|||||||
|
|
||||||
def get_matrix(self) -> torch.Tensor:
|
def get_matrix(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return a matrix which is the result of composing this transform
|
Returns a 4×4 matrix corresponding to each transform in the batch.
|
||||||
with others stored in self.transforms. Where necessary transforms
|
|
||||||
are broadcast against each other.
|
If the transform was composed from others, the matrix for the composite
|
||||||
|
transform will be returned.
|
||||||
For example, if self.transforms contains transforms t1, t2, and t3, and
|
For example, if self.transforms contains transforms t1, t2, and t3, and
|
||||||
given a set of points x, the following should be true:
|
given a set of points x, the following should be true:
|
||||||
|
|
||||||
@ -230,8 +232,11 @@ class Transform3d:
|
|||||||
y2 = t3.transform(t2.transform(t1.transform(x)))
|
y2 = t3.transform(t2.transform(t1.transform(x)))
|
||||||
y1.get_matrix() == y2.get_matrix()
|
y1.get_matrix() == y2.get_matrix()
|
||||||
|
|
||||||
|
Where necessary, those transforms are broadcast against each other.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A transformation matrix representing the composed inputs.
|
A (N, 4, 4) batch of transformation matrices representing
|
||||||
|
the stored transforms. See the class documentation for the conventions.
|
||||||
"""
|
"""
|
||||||
composed_matrix = self._matrix.clone()
|
composed_matrix = self._matrix.clone()
|
||||||
if len(self._transforms) > 0:
|
if len(self._transforms) > 0:
|
||||||
@ -240,6 +245,49 @@ class Transform3d:
|
|||||||
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
|
composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
|
||||||
return composed_matrix
|
return composed_matrix
|
||||||
|
|
||||||
|
def get_se3_log(self, eps: float = 1e-4, cos_bound: float = 1e-4) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Returns a 6D SE(3) log vector corresponding to each transform in the batch.
|
||||||
|
|
||||||
|
In the SE(3) logarithmic representation SE(3) matrices are
|
||||||
|
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
|
||||||
|
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
|
||||||
|
|
||||||
|
The conversion from the 4x4 SE(3) matrix `transform` to the
|
||||||
|
6D representation `log_transform = [log_translation | log_rotation]`
|
||||||
|
is done as follows:
|
||||||
|
```
|
||||||
|
log_transform = log(transform.get_matrix())
|
||||||
|
log_translation = log_transform[3, :3]
|
||||||
|
log_rotation = inv_hat(log_transform[:3, :3])
|
||||||
|
```
|
||||||
|
where `log` is the matrix logarithm
|
||||||
|
and `inv_hat` is the inverse of the Hat operator [2].
|
||||||
|
|
||||||
|
See the docstring for `se3.se3_log_map` and [1], Sec 9.4.2. for more
|
||||||
|
detailed description.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
eps: A threshold for clipping the squared norm of the rotation logarithm
|
||||||
|
to avoid division by zero in the singular case.
|
||||||
|
cos_bound: Clamps the cosine of the rotation angle to
|
||||||
|
[-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
|
||||||
|
The non-finite outputs can be caused by passing small rotation angles
|
||||||
|
to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A (N, 6) tensor, rows of which represent the individual transforms
|
||||||
|
stored in the object as SE(3) logarithms.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError if the stored transform is not Euclidean (e.g. R is not a rotation
|
||||||
|
matrix or the last column has non-zeros in the first three places).
|
||||||
|
|
||||||
|
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
|
||||||
|
[2] https://en.wikipedia.org/wiki/Hat_operator
|
||||||
|
"""
|
||||||
|
return se3_log_map(self.get_matrix(), eps, cos_bound)
|
||||||
|
|
||||||
def _get_matrix_inverse(self) -> torch.Tensor:
|
def _get_matrix_inverse(self) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Return the inverse of self._matrix.
|
Return the inverse of self._matrix.
|
||||||
|
@ -10,6 +10,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pytorch3d.transforms import random_rotations
|
from pytorch3d.transforms import random_rotations
|
||||||
|
from pytorch3d.transforms.se3 import se3_log_map
|
||||||
from pytorch3d.transforms.so3 import so3_exp_map
|
from pytorch3d.transforms.so3 import so3_exp_map
|
||||||
from pytorch3d.transforms.transform3d import (
|
from pytorch3d.transforms.transform3d import (
|
||||||
Rotate,
|
Rotate,
|
||||||
@ -161,6 +162,16 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
matrix = torch.randn(*bad_shape).float()
|
matrix = torch.randn(*bad_shape).float()
|
||||||
self.assertRaises(ValueError, Transform3d, matrix=matrix)
|
self.assertRaises(ValueError, Transform3d, matrix=matrix)
|
||||||
|
|
||||||
|
def test_get_se3(self):
|
||||||
|
N = 16
|
||||||
|
random_rotations(N)
|
||||||
|
tr = Translate(torch.rand((N, 3)))
|
||||||
|
R = Rotate(random_rotations(N))
|
||||||
|
transform = Transform3d().compose(R, tr)
|
||||||
|
se3_log = transform.get_se3_log()
|
||||||
|
gt_se3_log = se3_log_map(transform.get_matrix())
|
||||||
|
self.assertClose(se3_log, gt_se3_log)
|
||||||
|
|
||||||
def test_translate(self):
|
def test_translate(self):
|
||||||
t = Transform3d().translate(1, 2, 3)
|
t = Transform3d().translate(1, 2, 3)
|
||||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user