diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py
index 180f341f..580c02b6 100644
--- a/pytorch3d/transforms/transform3d.py
+++ b/pytorch3d/transforms/transform3d.py
@@ -390,16 +390,24 @@ class Transform3d:
         return normals_out
 
     def translate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Translate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Translate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def scale(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Scale(device=self.device, *args, **kwargs))
+        return self.compose(
+            Scale(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Rotate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
+        return self.compose(
+            RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def clone(self) -> "Transform3d":
         """
@@ -488,7 +496,7 @@ class Translate(Transform3d):
             - A 1D torch tensor
         """
         xyz = _handle_input(x, y, z, dtype, device, "Translate")
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
         N = xyz.shape[0]
 
         mat = torch.eye(4, dtype=dtype, device=self.device)
@@ -532,7 +540,7 @@ class Scale(Transform3d):
             - 1D torch tensor
         """
         xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
         N = xyz.shape[0]
 
         # TODO: Can we do this all in one go somehow?
@@ -571,7 +579,7 @@ class Rotate(Transform3d):
 
         """
         device_ = get_device(R, device)
-        super().__init__(device=device_)
+        super().__init__(device=device_, dtype=dtype)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
@@ -598,7 +606,7 @@ class RotateAxisAngle(Rotate):
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype: torch.dtype = torch.float64,
+        dtype: torch.dtype = torch.float32,
         device: Optional[Device] = None,
     ) -> None:
         """
@@ -629,7 +637,7 @@ class RotateAxisAngle(Rotate):
         # is for transforming column vectors. Therefore we transpose this matrix.
         # R will always be of shape (N, 3, 3)
         R = _axis_angle_rotation(axis, angle).transpose(1, 2)
-        super().__init__(device=angle.device, R=R)
+        super().__init__(device=angle.device, R=R, dtype=dtype)
 
 
 def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -646,8 +654,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
         c = torch.tensor(c, dtype=dtype, device=device)
     if c.dim() == 0:
         c = c.view(1)
-    if c.device != device:
-        c = c.to(device=device)
+    if c.device != device or c.dtype != dtype:
+        c = c.to(device=device, dtype=dtype)
     return c
 
 
@@ -696,7 +704,7 @@ def _handle_input(
         if y is not None or z is not None:
            msg = "Expected y and z to be None (in %s)" % name
            raise ValueError(msg)
-        return x.to(device=device_)
+        return x.to(device=device_, dtype=dtype)
 
     if allow_singleton and y is None and z is None:
         y = x
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index f4690a41..37a8d0df 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -87,6 +87,36 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         t = t.cuda()
         t = t.cpu()
 
+    def test_dtype_propagation(self):
+        """
+        Check that a given dtype is correctly passed along to child
+        transformations.
+        """
+        # Use at least two dtypes so we avoid only testing on the
+        # default dtype.
+        for dtype in [torch.float32, torch.float64]:
+            R = torch.tensor(
+                [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
+                dtype=dtype,
+            )
+            tf = (
+                Transform3d(dtype=dtype)
+                .rotate(R)
+                .rotate_axis_angle(
+                    R[0],
+                    "X",
+                )
+                .translate(3, 2, 1)
+                .scale(0.5)
+            )
+
+            self.assertEqual(tf.dtype, dtype)
+            for inner_tf in tf._transforms:
+                self.assertEqual(inner_tf.dtype, dtype)
+
+            transformed = tf.transform_points(R)
+            self.assertEqual(transformed.dtype, dtype)
+
     def test_clone(self):
         """
         Check that cloned transformations contain different _matrix objects.
@@ -219,8 +249,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         normals_out_expected = torch.tensor(
             [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
         ).view(1, 3, 3)
-        self.assertTrue(torch.allclose(points_out, points_out_expected))
-        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
+        self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7))
+        self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7))
 
     def test_transform_points_fail(self):
         t1 = Scale(0.1, 0.1, 0.1)
@@ -951,7 +981,7 @@ class TestRotateAxisAngle(unittest.TestCase):
         self.assertTrue(
             torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
         )
-        self.assertTrue(torch.allclose(t._matrix, matrix))
+        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 
     def test_rotate_x_torch_scalar(self):
         angle = torch.tensor(90.0)
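A minimal usage sketch (illustrative only, not part of the patch) of the behavior
this change fixes: previously the fluent methods translate/scale/rotate/
rotate_axis_angle did not forward the parent Transform3d's dtype, so composed
children fell back to their constructor defaults (float32, except RotateAxisAngle,
which defaulted to float64). With this patch the requested dtype propagates
through composition:

    import torch
    from pytorch3d.transforms import Transform3d

    t = Transform3d(dtype=torch.float64).translate(1.0, 2.0, 3.0).scale(2.0)
    # All composed child transforms now inherit the parent's dtype.
    assert t.dtype == torch.float64
    assert all(inner.dtype == torch.float64 for inner in t._transforms)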