diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py index 066cfeac..b4140f29 100644 --- a/pytorch3d/transforms/rotation_conversions.py +++ b/pytorch3d/transforms/rotation_conversions.py @@ -282,9 +282,7 @@ def matrix_to_euler_angles(matrix, convention: str): return torch.stack(o, -1) -def random_quaternions( - n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False -): +def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None): """ Generate random quaternions representing rotations, i.e. versors with nonnegative real part. @@ -294,21 +292,17 @@ def random_quaternions( dtype: Type to return. device: Desired device of returned tensor. Default: uses the current device for the default tensor type. - requires_grad: Whether the resulting tensor should have the gradient - flag set. Returns: Quaternions as tensor of shape (N, 4). """ - o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad) + o = torch.randn((n, 4), dtype=dtype, device=device) s = (o * o).sum(1) o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None] return o -def random_rotations( - n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False -): +def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None): """ Generate random rotations as 3x3 rotation matrices. @@ -317,21 +311,15 @@ def random_rotations( dtype: Type to return. device: Device of returned tensor. Default: if None, uses the current device for the default tensor type. - requires_grad: Whether the resulting tensor should have the gradient - flag set. Returns: Rotation matrices as tensor of shape (n, 3, 3). """ - quaternions = random_quaternions( - n, dtype=dtype, device=device, requires_grad=requires_grad - ) + quaternions = random_quaternions(n, dtype=dtype, device=device) return quaternion_to_matrix(quaternions) -def random_rotation( - dtype: Optional[torch.dtype] = None, device=None, requires_grad=False -): +def random_rotation(dtype: Optional[torch.dtype] = None, device=None): """ Generate a single random 3x3 rotation matrix. @@ -339,13 +327,11 @@ def random_rotation( dtype: Type to return device: Device of returned tensor. Default: if None, uses the current device for the default tensor type - requires_grad: Whether the resulting tensor should have the gradient - flag set Returns: Rotation matrix as tensor of shape (3, 3). """ - return random_rotations(1, dtype, device, requires_grad)[0] + return random_rotations(1, dtype, device)[0] def standardize_quaternion(quaternions): diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py index 9875d01a..afcf9c47 100644 --- a/tests/test_rotation_conversions.py +++ b/tests/test_rotation_conversions.py @@ -76,7 +76,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): def test_quat_grad_exists(self): """Quaternion calculations are differentiable.""" - rotation = random_rotation(requires_grad=True) + rotation = random_rotation() + rotation.requires_grad = True modified = quaternion_to_matrix(matrix_to_quaternion(rotation)) [g] = torch.autograd.grad(modified.sum(), rotation) self.assertTrue(torch.isfinite(g).all()) @@ -131,7 +132,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): def test_euler_grad_exists(self): """Euler angle calculations are differentiable.""" - rotation = random_rotation(dtype=torch.float64, requires_grad=True) + rotation = random_rotation(dtype=torch.float64) + rotation.requires_grad = True for convention in self._all_euler_angle_conventions(): euler_angles = matrix_to_euler_angles(rotation, convention) mdata = euler_angles_to_matrix(euler_angles, convention) @@ -218,7 +220,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase): def test_quaternion_application(self): """Applying a quaternion is the same as applying the matrix.""" - quaternions = random_quaternions(3, torch.float64, requires_grad=True) + quaternions = random_quaternions(3, torch.float64) + quaternions.requires_grad = True matrices = quaternion_to_matrix(quaternions) points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True) transform1 = quaternion_apply(quaternions, points)