Extending the API of Transform3d with SE(3) log

Summary:
This is quite a thin wrapper – not sure we need it. The motivation is that `Transform3d` is not as matrix-centric now, it can be converted to SE(3) logarithm equally easily.

It simplifies things like averaging cameras and getting axis-angle of camera rotation (previously, one would need to call `se3_log_map(cameras.get_world_to_camera_transform().get_matrix())`), now one fewer thing to call / discover.

Reviewed By: bottler

Differential Revision: D39928000

fbshipit-source-id: 85248d5b8af136618f1d08791af5297ea5179d19
This commit is contained in:
Roman Shapovalov
2022-09-29 11:56:14 -07:00
committed by Facebook GitHub Bot
parent 74bbd6fd76
commit 9a0f9ae572
2 changed files with 66 additions and 7 deletions

View File

@@ -10,6 +10,7 @@ import unittest
import torch
from pytorch3d.transforms import random_rotations
from pytorch3d.transforms.se3 import se3_log_map
from pytorch3d.transforms.so3 import so3_exp_map
from pytorch3d.transforms.transform3d import (
Rotate,
@@ -161,6 +162,16 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
matrix = torch.randn(*bad_shape).float()
self.assertRaises(ValueError, Transform3d, matrix=matrix)
def test_get_se3(self):
N = 16
random_rotations(N)
tr = Translate(torch.rand((N, 3)))
R = Rotate(random_rotations(N))
transform = Transform3d().compose(R, tr)
se3_log = transform.get_se3_log()
gt_se3_log = se3_log_map(transform.get_matrix())
self.assertClose(se3_log, gt_se3_log)
def test_translate(self):
t = Transform3d().translate(1, 2, 3)
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(