mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Initialization of Transform3D with a custom matrix.
Summary: Allows to initialize a Transform3D object with a batch of user-defined transformation matrices: ``` t = Transform3D(matrix=torch.randn(2, 4, 4)) ``` Reviewed By: nikhilaravi Differential Revision: D20693475 fbshipit-source-id: dccc49b2ca4c19a034844c63463953ba8f52c1bc
This commit is contained in:
committed by
Facebook GitHub Bot
parent
e37085d999
commit
90dc7a0856
@@ -57,6 +57,19 @@ class TestTransform(unittest.TestCase):
|
||||
matrix2 = t_pair[1].get_matrix()
|
||||
self.assertTrue(torch.allclose(matrix1, matrix2))
|
||||
|
||||
def test_init_with_custom_matrix(self):
|
||||
for matrix in (torch.randn(10, 4, 4), torch.randn(4, 4)):
|
||||
t = Transform3d(matrix=matrix)
|
||||
self.assertTrue(t.device == matrix.device)
|
||||
self.assertTrue(t._matrix.dtype == matrix.dtype)
|
||||
self.assertTrue(torch.allclose(t._matrix, matrix.view(t._matrix.shape)))
|
||||
|
||||
def test_init_with_custom_matrix_errors(self):
|
||||
bad_shapes = [[10, 5, 4], [3, 4], [10, 4, 4, 1], [10, 4, 4, 2], [4, 4, 4, 3]]
|
||||
for bad_shape in bad_shapes:
|
||||
matrix = torch.randn(*bad_shape).float()
|
||||
self.assertRaises(ValueError, Transform3d, matrix=matrix)
|
||||
|
||||
def test_translate(self):
|
||||
t = Transform3d().translate(1, 2, 3)
|
||||
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
||||
|
||||
Reference in New Issue
Block a user