mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Extending the API of Transform3d with SE(3) log
Summary: This is quite a thin wrapper – not sure we need it. The motivation is that `Transform3d` is not as matrix-centric now, it can be converted to SE(3) logarithm equally easily. It simplifies things like averaging cameras and getting axis-angle of camera rotation (previously, one would need to call `se3_log_map(cameras.get_world_to_camera_transform().get_matrix())`), now one fewer thing to call / discover. Reviewed By: bottler Differential Revision: D39928000 fbshipit-source-id: 85248d5b8af136618f1d08791af5297ea5179d19
This commit is contained in:
		
							parent
							
								
									74bbd6fd76
								
							
						
					
					
						commit
						9a0f9ae572
					
				@ -13,6 +13,7 @@ import torch
 | 
			
		||||
from ..common.datatypes import Device, get_device, make_device
 | 
			
		||||
from ..common.workaround import _safe_det_3x3
 | 
			
		||||
from .rotation_conversions import _axis_angle_rotation
 | 
			
		||||
from .se3 import se3_log_map
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Transform3d:
 | 
			
		||||
@ -130,13 +131,13 @@ class Transform3d:
 | 
			
		||||
                [Tx,  Ty,  Tz,  1],
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
    To apply the transformation to points which are row vectors, the M matrix
 | 
			
		||||
    can be pre multiplied by the points:
 | 
			
		||||
    To apply the transformation to points, which are row vectors, the latter are
 | 
			
		||||
    converted to homogeneous (4D) coordinates and right-multiplied by the M matrix:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        points = [[0, 1, 2]]  # (1 x 3) xyz coordinates of a point
 | 
			
		||||
        transformed_points = points * M
 | 
			
		||||
        [transformed_points, 1] ∝ [points, 1] @ M
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
@ -218,9 +219,10 @@ class Transform3d:
 | 
			
		||||
 | 
			
		||||
    def get_matrix(self) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Return a matrix which is the result of composing this transform
 | 
			
		||||
        with others stored in self.transforms. Where necessary transforms
 | 
			
		||||
        are broadcast against each other.
 | 
			
		||||
        Returns a 4×4 matrix corresponding to each transform in the batch.
 | 
			
		||||
 | 
			
		||||
        If the transform was composed from others, the matrix for the composite
 | 
			
		||||
        transform will be returned.
 | 
			
		||||
        For example, if self.transforms contains transforms t1, t2, and t3, and
 | 
			
		||||
        given a set of points x, the following should be true:
 | 
			
		||||
 | 
			
		||||
@ -230,8 +232,11 @@ class Transform3d:
 | 
			
		||||
            y2 = t3.transform(t2.transform(t1.transform(x)))
 | 
			
		||||
            y1.get_matrix() == y2.get_matrix()
 | 
			
		||||
 | 
			
		||||
        Where necessary, those transforms are broadcast against each other.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A transformation matrix representing the composed inputs.
 | 
			
		||||
            A (N, 4, 4) batch of transformation matrices representing
 | 
			
		||||
                the stored transforms. See the class documentation for the conventions.
 | 
			
		||||
        """
 | 
			
		||||
        composed_matrix = self._matrix.clone()
 | 
			
		||||
        if len(self._transforms) > 0:
 | 
			
		||||
@ -240,6 +245,49 @@ class Transform3d:
 | 
			
		||||
                composed_matrix = _broadcast_bmm(composed_matrix, other_matrix)
 | 
			
		||||
        return composed_matrix
 | 
			
		||||
 | 
			
		||||
    def get_se3_log(self, eps: float = 1e-4, cos_bound: float = 1e-4) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Returns a 6D SE(3) log vector corresponding to each transform in the batch.
 | 
			
		||||
 | 
			
		||||
        In the SE(3) logarithmic representation SE(3) matrices are
 | 
			
		||||
        represented as 6-dimensional vectors `[log_translation | log_rotation]`,
 | 
			
		||||
        i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
 | 
			
		||||
 | 
			
		||||
        The conversion from the 4x4 SE(3) matrix `transform` to the
 | 
			
		||||
        6D representation `log_transform = [log_translation | log_rotation]`
 | 
			
		||||
        is done as follows:
 | 
			
		||||
            ```
 | 
			
		||||
            log_transform = log(transform.get_matrix())
 | 
			
		||||
            log_translation = log_transform[3, :3]
 | 
			
		||||
            log_rotation = inv_hat(log_transform[:3, :3])
 | 
			
		||||
            ```
 | 
			
		||||
        where `log` is the matrix logarithm
 | 
			
		||||
        and `inv_hat` is the inverse of the Hat operator [2].
 | 
			
		||||
 | 
			
		||||
        See the docstring for `se3.se3_log_map` and [1], Sec 9.4.2. for more
 | 
			
		||||
        detailed description.
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            eps: A threshold for clipping the squared norm of the rotation logarithm
 | 
			
		||||
                to avoid division by zero in the singular case.
 | 
			
		||||
            cos_bound: Clamps the cosine of the rotation angle to
 | 
			
		||||
                [-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
 | 
			
		||||
                The non-finite outputs can be caused by passing small rotation angles
 | 
			
		||||
                to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            A (N, 6) tensor, rows of which represent the individual transforms
 | 
			
		||||
            stored in the object as SE(3) logarithms.
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            ValueError if the stored transform is not Euclidean (e.g. R is not a rotation
 | 
			
		||||
                matrix or the last column has non-zeros in the first three places).
 | 
			
		||||
 | 
			
		||||
        [1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
 | 
			
		||||
        [2] https://en.wikipedia.org/wiki/Hat_operator
 | 
			
		||||
        """
 | 
			
		||||
        return se3_log_map(self.get_matrix(), eps, cos_bound)
 | 
			
		||||
 | 
			
		||||
    def _get_matrix_inverse(self) -> torch.Tensor:
 | 
			
		||||
        """
 | 
			
		||||
        Return the inverse of self._matrix.
 | 
			
		||||
 | 
			
		||||
@ -10,6 +10,7 @@ import unittest
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from pytorch3d.transforms import random_rotations
 | 
			
		||||
from pytorch3d.transforms.se3 import se3_log_map
 | 
			
		||||
from pytorch3d.transforms.so3 import so3_exp_map
 | 
			
		||||
from pytorch3d.transforms.transform3d import (
 | 
			
		||||
    Rotate,
 | 
			
		||||
@ -161,6 +162,16 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
            matrix = torch.randn(*bad_shape).float()
 | 
			
		||||
            self.assertRaises(ValueError, Transform3d, matrix=matrix)
 | 
			
		||||
 | 
			
		||||
    def test_get_se3(self):
 | 
			
		||||
        N = 16
 | 
			
		||||
        random_rotations(N)
 | 
			
		||||
        tr = Translate(torch.rand((N, 3)))
 | 
			
		||||
        R = Rotate(random_rotations(N))
 | 
			
		||||
        transform = Transform3d().compose(R, tr)
 | 
			
		||||
        se3_log = transform.get_se3_log()
 | 
			
		||||
        gt_se3_log = se3_log_map(transform.get_matrix())
 | 
			
		||||
        self.assertClose(se3_log, gt_se3_log)
 | 
			
		||||
 | 
			
		||||
    def test_translate(self):
 | 
			
		||||
        t = Transform3d().translate(1, 2, 3)
 | 
			
		||||
        points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user