mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
matrix_to_quaternion corner case
Summary: Issue #119. The function `sqrt(max(x, 0))` is not convex and has infinite gradient at 0, but 0 is a subgradient at 0. Here we implement it in such a way as to give 0 as the gradient. Reviewed By: gkioxari Differential Revision: D24306294 fbshipit-source-id: 48d136faca083babad4d64970be7ea522dbe9e09
This commit is contained in:
parent
2d39723610
commit
4d52f9fb8b
@ -82,6 +82,17 @@ def _copysign(a, b):
|
||||
return torch.where(signs_differ, -a, a)
|
||||
|
||||
|
||||
def _sqrt_positive_part(x):
|
||||
"""
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
return ret
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix):
|
||||
"""
|
||||
Convert rotations given as rotation matrices to quaternions.
|
||||
@ -94,14 +105,13 @@ def matrix_to_quaternion(matrix):
|
||||
"""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
||||
zero = matrix.new_zeros((1,))
|
||||
m00 = matrix[..., 0, 0]
|
||||
m11 = matrix[..., 1, 1]
|
||||
m22 = matrix[..., 2, 2]
|
||||
o0 = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 + m11 + m22))
|
||||
x = 0.5 * torch.sqrt(torch.max(zero, 1 + m00 - m11 - m22))
|
||||
y = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 + m11 - m22))
|
||||
z = 0.5 * torch.sqrt(torch.max(zero, 1 - m00 - m11 + m22))
|
||||
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
||||
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
||||
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
||||
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
||||
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
||||
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
||||
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
||||
|
@ -145,6 +145,20 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
||||
self.assertEqual(ab.shape, ab_from_matrix.shape)
|
||||
self.assertTrue(torch.allclose(ab, ab_from_matrix))
|
||||
|
||||
def test_matrix_to_quaternion_corner_case(self):
|
||||
"""Check no bad gradients from sqrt(0)."""
|
||||
matrix = torch.eye(3, requires_grad=True)
|
||||
target = torch.Tensor([0.984808, 0, 0.174, 0])
|
||||
|
||||
optimizer = torch.optim.Adam([matrix], lr=0.05)
|
||||
optimizer.zero_grad()
|
||||
q = matrix_to_quaternion(matrix)
|
||||
loss = torch.sum((q - target) ** 2)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
self.assertClose(matrix, 0.95 * torch.eye(3))
|
||||
|
||||
def test_quaternion_application(self):
|
||||
"""Applying a quaternion is the same as applying the matrix."""
|
||||
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user