Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Fix dtype propagation (#1141)
Summary: Previously, dtypes were not propagated correctly through composed transforms, causing errors when different dtypes were mixed. Specifying a dtype in the constructor did not fix this, and neither did passing the dtype to each composition call (e.g. as a kwarg to `rotate_axis_angle`).

As part of the fix, the default dtype of `RotateAxisAngle` changes from `torch.float64` to `torch.float32`, matching all other transforms; with correct propagation, the old default broke several tests with dtype mismatches. The new default in turn reduced precision in two tests (calculations previously done in `torch.float64` now run in `torch.float32`), so their tolerances were loosened to the lowest power of ten that passes.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1141

Reviewed By: patricklabatut

Differential Revision: D35192970

Pulled By: bottler

fbshipit-source-id: ba0293e8b3595dfc94b3cf8048e50b7a5e5ed7cf
commit b602edccc4
parent 21262e38c7
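To make the bug concrete, here is a minimal sketch, assuming only the public pytorch3d.transforms API, of the behavior this commit fixes. Before the change, the convenience methods on Transform3d built their child transforms with the default float32 dtype, so chaining them onto a float64 transform mixed dtypes when the 4x4 matrices were composed:

import torch
from pytorch3d.transforms import Transform3d

# Before this fix, .translate() created a float32 Translate child inside
# a float64 Transform3d, and composing the 4x4 matrices mixed dtypes.
# After the fix, the child inherits torch.float64 from its parent.
t = Transform3d(dtype=torch.float64).translate(1.0, 2.0, 3.0)
assert t.get_matrix().dtype == torch.float64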
pytorch3d/transforms/transform3d.py
@@ -390,16 +390,24 @@ class Transform3d:
         return normals_out
 
     def translate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Translate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Translate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def scale(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Scale(device=self.device, *args, **kwargs))
+        return self.compose(
+            Scale(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(Rotate(device=self.device, *args, **kwargs))
+        return self.compose(
+            Rotate(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def rotate_axis_angle(self, *args, **kwargs) -> "Transform3d":
-        return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
+        return self.compose(
+            RotateAxisAngle(device=self.device, dtype=self.dtype, *args, **kwargs)
+        )
 
     def clone(self) -> "Transform3d":
         """
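A side note on the call style introduced above: keyword arguments may precede *args at a Python call site, and the unpacked positionals still bind to the leading parameters. A standalone illustration with a hypothetical function (not library code):

def f(x, y, z, device=None, dtype=None):
    return (x, y, z, device, dtype)

# Keywords first, positional unpacking second: x, y, z still get 1, 2, 3.
assert f(device="cpu", dtype="fp32", *(1, 2, 3)) == (1, 2, 3, "cpu", "fp32")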
@@ -488,7 +496,7 @@ class Translate(Transform3d):
             - A 1D torch tensor
         """
         xyz = _handle_input(x, y, z, dtype, device, "Translate")
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
         N = xyz.shape[0]
 
         mat = torch.eye(4, dtype=dtype, device=self.device)
@@ -532,7 +540,7 @@ class Scale(Transform3d):
             - 1D torch tensor
         """
         xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
-        super().__init__(device=xyz.device)
+        super().__init__(device=xyz.device, dtype=dtype)
         N = xyz.shape[0]
 
         # TODO: Can we do this all in one go somehow?
@@ -571,7 +579,7 @@ class Rotate(Transform3d):
 
         """
        device_ = get_device(R, device)
-        super().__init__(device=device_)
+        super().__init__(device=device_, dtype=dtype)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
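The constructor hunks above all follow one pattern: each subclass now forwards its dtype to Transform3d.__init__, so the transform's recorded dtype matches the matrix it actually builds. A quick check of the resulting behavior, sketched against the public API:

import torch
from pytorch3d.transforms import Translate

# The recorded dtype and the matrix dtype now agree.
tr = Translate(1.0, 2.0, 3.0, dtype=torch.float64)
assert tr.dtype == torch.float64
assert tr.get_matrix().dtype == torch.float64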
@@ -598,7 +606,7 @@ class RotateAxisAngle(Rotate):
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype: torch.dtype = torch.float64,
+        dtype: torch.dtype = torch.float32,
         device: Optional[Device] = None,
     ) -> None:
         """
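This hunk is the one user-visible behavior change: RotateAxisAngle previously defaulted to torch.float64 and now defaults to torch.float32, matching every other transform. Callers who relied on float64 precision must now opt in explicitly; a sketch against the public API:

import torch
from pytorch3d.transforms import RotateAxisAngle

# New default: float32, consistent with Translate/Scale/Rotate.
assert RotateAxisAngle(90.0, axis="X").get_matrix().dtype == torch.float32

# float64 remains available by passing dtype explicitly.
r64 = RotateAxisAngle(90.0, axis="X", dtype=torch.float64)
assert r64.get_matrix().dtype == torch.float64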
@@ -629,7 +637,7 @@ class RotateAxisAngle(Rotate):
         # is for transforming column vectors. Therefore we transpose this matrix.
         # R will always be of shape (N, 3, 3)
         R = _axis_angle_rotation(axis, angle).transpose(1, 2)
-        super().__init__(device=angle.device, R=R)
+        super().__init__(device=angle.device, R=R, dtype=dtype)
 
 
 def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
@@ -646,8 +654,8 @@ def _handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
         c = torch.tensor(c, dtype=dtype, device=device)
     if c.dim() == 0:
         c = c.view(1)
-    if c.device != device:
-        c = c.to(device=device)
+    if c.device != device or c.dtype != dtype:
+        c = c.to(device=device, dtype=dtype)
     return c
 
 
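_handle_coord previously fixed up only the device of tensor inputs, so a tensor arriving with the wrong dtype passed through unchanged. A self-contained sketch of the corrected contract, mirroring the diff rather than the library's exact source:

import torch

def handle_coord(c, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Scalars and lists become tensors of the requested dtype and device.
    if not torch.is_tensor(c):
        c = torch.tensor(c, dtype=dtype, device=device)
    if c.dim() == 0:
        c = c.view(1)
    # The fix: cast mismatched dtypes too, not just mismatched devices.
    if c.device != device or c.dtype != dtype:
        c = c.to(device=device, dtype=dtype)
    return c

out = handle_coord(torch.tensor([1.0], dtype=torch.float64), torch.float32, torch.device("cpu"))
assert out.dtype == torch.float32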
@@ -696,7 +704,7 @@ def _handle_input(
         if y is not None or z is not None:
             msg = "Expected y and z to be None (in %s)" % name
             raise ValueError(msg)
-        return x.to(device=device_)
+        return x.to(device=device_, dtype=dtype)
 
     if allow_singleton and y is None and z is None:
         y = x
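_handle_input gets the matching fix on its tensor fast path: a ready-made (N, 3) input is now cast to the requested dtype instead of only being moved to the right device. The observable effect, sketched against the public API:

import torch
from pytorch3d.transforms import Translate

# A float64 (N, 3) input is now cast to the transform's float32 default.
xyz = torch.zeros(2, 3, dtype=torch.float64)
assert Translate(xyz).get_matrix().dtype == torch.float32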
tests/test_transforms.py
@@ -87,6 +87,36 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
             t = t.cuda()
             t = t.cpu()
 
+    def test_dtype_propagation(self):
+        """
+        Check that a given dtype is correctly passed along to child
+        transformations.
+        """
+        # Use at least two dtypes so we avoid only testing on the
+        # default dtype.
+        for dtype in [torch.float32, torch.float64]:
+            R = torch.tensor(
+                [[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]],
+                dtype=dtype,
+            )
+            tf = (
+                Transform3d(dtype=dtype)
+                .rotate(R)
+                .rotate_axis_angle(
+                    R[0],
+                    "X",
+                )
+                .translate(3, 2, 1)
+                .scale(0.5)
+            )
+
+            self.assertEqual(tf.dtype, dtype)
+            for inner_tf in tf._transforms:
+                self.assertEqual(inner_tf.dtype, dtype)
+
+            transformed = tf.transform_points(R)
+            self.assertEqual(transformed.dtype, dtype)
+
     def test_clone(self):
         """
         Check that cloned transformations contain different _matrix objects.
@@ -219,8 +249,8 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
         normals_out_expected = torch.tensor(
             [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
         ).view(1, 3, 3)
-        self.assertTrue(torch.allclose(points_out, points_out_expected))
-        self.assertTrue(torch.allclose(normals_out, normals_out_expected))
+        self.assertTrue(torch.allclose(points_out, points_out_expected, atol=1e-7))
+        self.assertTrue(torch.allclose(normals_out, normals_out_expected, atol=1e-7))
 
     def test_transform_points_fail(self):
         t1 = Scale(0.1, 0.1, 0.1)
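The looser tolerances follow directly from the default-dtype change: math that previously ran in float64 now runs in float32, whose machine epsilon is about 1.2e-7, so torch.allclose's default atol of 1e-8 is stricter than float32 rounding error near zero. A quick check of the relevant magnitudes:

import torch

# Machine epsilon: the gap between 1.0 and the next representable value.
print(torch.finfo(torch.float32).eps)  # ~1.1920929e-07
print(torch.finfo(torch.float64).eps)  # ~2.220446e-16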
@@ -951,7 +981,7 @@ class TestRotateAxisAngle(unittest.TestCase):
         self.assertTrue(
             torch.allclose(transformed_points.squeeze(), expected_points, atol=1e-7)
         )
-        self.assertTrue(torch.allclose(t._matrix, matrix))
+        self.assertTrue(torch.allclose(t._matrix, matrix, atol=1e-7))
 
     def test_rotate_x_torch_scalar(self):
         angle = torch.tensor(90.0)