diff --git a/pytorch3d/transforms/rotation_conversions.py b/pytorch3d/transforms/rotation_conversions.py
index e322280d..9e41ed3e 100644
--- a/pytorch3d/transforms/rotation_conversions.py
+++ b/pytorch3d/transforms/rotation_conversions.py
@@ -82,6 +82,25 @@ def _copysign(a, b):
     return torch.where(signs_differ, -a, a)
 
 
+def _sqrt_positive_part(x):
+    """
+    Returns torch.sqrt(torch.max(0, x))
+    but with a zero subgradient where x is 0.
+
+    Args:
+        x: tensor of any shape.
+
+    Returns:
+        Tensor shaped like x: sqrt(x) where x > 0, and 0 elsewhere.
+    """
+    ret = torch.zeros_like(x)
+    positive_mask = x > 0
+    # Only strictly positive entries pass through torch.sqrt, so its
+    # derivative is never evaluated at 0 during the backward pass.
+    ret[positive_mask] = torch.sqrt(x[positive_mask])
+    return ret
+
+
 def matrix_to_quaternion(matrix):
     """
     Convert rotations given as rotation matrices to quaternions.
@@ -94,14 +113,13 @@ def matrix_to_quaternion(matrix):
     """
     if matrix.size(-1) != 3 or matrix.size(-2) != 3:
-        raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
+        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
-    zero = matrix.new_zeros((1,))
     m00 = matrix[..., 0, 0]
     m11 = matrix[..., 1, 1]
     m22 = matrix[..., 2, 2]
-    o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
-    x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
-    y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
-    z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
+    o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
+    x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
+    y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
+    z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
     o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
     o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
     o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py
index f8bc60fb..74dcd1a0 100644
--- a/tests/test_rotation_conversions.py
+++ b/tests/test_rotation_conversions.py
@@ -145,6 +145,23 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
         self.assertEqual(ab.shape, ab_from_matrix.shape)
         self.assertTrue(torch.allclose(ab, ab_from_matrix))
 
+    def test_matrix_to_quaternion_corner_case(self):
+        """Check no bad gradients from sqrt(0)."""
+        matrix = torch.eye(3, requires_grad=True)
+        target = torch.tensor([0.984808, 0.0, 0.174, 0.0])
+
+        optimizer = torch.optim.Adam([matrix], lr=0.05)
+        optimizer.zero_grad()
+        q = matrix_to_quaternion(matrix)
+        loss = torch.sum((q - target) ** 2)
+        loss.backward()
+        optimizer.step()
+
+        # The first Adam step moves each entry that has a nonzero gradient
+        # by about lr, so the diagonal should drop from 1 to 0.95; a NaN
+        # gradient from sqrt(0) would break this comparison.
+        self.assertClose(matrix, 0.95 * torch.eye(3))
+
     def test_quaternion_application(self):
         """Applying a quaternion is the same as applying the matrix."""
         quaternions = random_quaternions(3, torch.float64, requires_grad=True)