mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
remove requires_grad from random rotations
Summary: Because rotations and (rotation) quaternions live on curved manifolds, it doesn't make sense to optimize them directly. Having a prominent option to require gradient on random ones may cause people to try, and isn't particularly useful. Reviewed By: theschnitz Differential Revision: D29160734 fbshipit-source-id: fc9e320672349fe334747c5b214655882a460a62
This commit is contained in:
parent
31c448a95d
commit
ce60d4b00e
@ -282,9 +282,7 @@ def matrix_to_euler_angles(matrix, convention: str):
|
|||||||
return torch.stack(o, -1)
|
return torch.stack(o, -1)
|
||||||
|
|
||||||
|
|
||||||
def random_quaternions(
|
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
||||||
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Generate random quaternions representing rotations,
|
Generate random quaternions representing rotations,
|
||||||
i.e. versors with nonnegative real part.
|
i.e. versors with nonnegative real part.
|
||||||
@ -294,21 +292,17 @@ def random_quaternions(
|
|||||||
dtype: Type to return.
|
dtype: Type to return.
|
||||||
device: Desired device of returned tensor. Default:
|
device: Desired device of returned tensor. Default:
|
||||||
uses the current device for the default tensor type.
|
uses the current device for the default tensor type.
|
||||||
requires_grad: Whether the resulting tensor should have the gradient
|
|
||||||
flag set.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Quaternions as tensor of shape (N, 4).
|
Quaternions as tensor of shape (N, 4).
|
||||||
"""
|
"""
|
||||||
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
o = torch.randn((n, 4), dtype=dtype, device=device)
|
||||||
s = (o * o).sum(1)
|
s = (o * o).sum(1)
|
||||||
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|
||||||
def random_rotations(
|
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
|
||||||
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Generate random rotations as 3x3 rotation matrices.
|
Generate random rotations as 3x3 rotation matrices.
|
||||||
|
|
||||||
@ -317,21 +311,15 @@ def random_rotations(
|
|||||||
dtype: Type to return.
|
dtype: Type to return.
|
||||||
device: Device of returned tensor. Default: if None,
|
device: Device of returned tensor. Default: if None,
|
||||||
uses the current device for the default tensor type.
|
uses the current device for the default tensor type.
|
||||||
requires_grad: Whether the resulting tensor should have the gradient
|
|
||||||
flag set.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rotation matrices as tensor of shape (n, 3, 3).
|
Rotation matrices as tensor of shape (n, 3, 3).
|
||||||
"""
|
"""
|
||||||
quaternions = random_quaternions(
|
quaternions = random_quaternions(n, dtype=dtype, device=device)
|
||||||
n, dtype=dtype, device=device, requires_grad=requires_grad
|
|
||||||
)
|
|
||||||
return quaternion_to_matrix(quaternions)
|
return quaternion_to_matrix(quaternions)
|
||||||
|
|
||||||
|
|
||||||
def random_rotation(
|
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
|
||||||
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Generate a single random 3x3 rotation matrix.
|
Generate a single random 3x3 rotation matrix.
|
||||||
|
|
||||||
@ -339,13 +327,11 @@ def random_rotation(
|
|||||||
dtype: Type to return
|
dtype: Type to return
|
||||||
device: Device of returned tensor. Default: if None,
|
device: Device of returned tensor. Default: if None,
|
||||||
uses the current device for the default tensor type
|
uses the current device for the default tensor type
|
||||||
requires_grad: Whether the resulting tensor should have the gradient
|
|
||||||
flag set
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Rotation matrix as tensor of shape (3, 3).
|
Rotation matrix as tensor of shape (3, 3).
|
||||||
"""
|
"""
|
||||||
return random_rotations(1, dtype, device, requires_grad)[0]
|
return random_rotations(1, dtype, device)[0]
|
||||||
|
|
||||||
|
|
||||||
def standardize_quaternion(quaternions):
|
def standardize_quaternion(quaternions):
|
||||||
|
@ -76,7 +76,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_quat_grad_exists(self):
|
def test_quat_grad_exists(self):
|
||||||
"""Quaternion calculations are differentiable."""
|
"""Quaternion calculations are differentiable."""
|
||||||
rotation = random_rotation(requires_grad=True)
|
rotation = random_rotation()
|
||||||
|
rotation.requires_grad = True
|
||||||
modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
|
modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
|
||||||
[g] = torch.autograd.grad(modified.sum(), rotation)
|
[g] = torch.autograd.grad(modified.sum(), rotation)
|
||||||
self.assertTrue(torch.isfinite(g).all())
|
self.assertTrue(torch.isfinite(g).all())
|
||||||
@ -131,7 +132,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_euler_grad_exists(self):
|
def test_euler_grad_exists(self):
|
||||||
"""Euler angle calculations are differentiable."""
|
"""Euler angle calculations are differentiable."""
|
||||||
rotation = random_rotation(dtype=torch.float64, requires_grad=True)
|
rotation = random_rotation(dtype=torch.float64)
|
||||||
|
rotation.requires_grad = True
|
||||||
for convention in self._all_euler_angle_conventions():
|
for convention in self._all_euler_angle_conventions():
|
||||||
euler_angles = matrix_to_euler_angles(rotation, convention)
|
euler_angles = matrix_to_euler_angles(rotation, convention)
|
||||||
mdata = euler_angles_to_matrix(euler_angles, convention)
|
mdata = euler_angles_to_matrix(euler_angles, convention)
|
||||||
@ -218,7 +220,8 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
|
|||||||
|
|
||||||
def test_quaternion_application(self):
|
def test_quaternion_application(self):
|
||||||
"""Applying a quaternion is the same as applying the matrix."""
|
"""Applying a quaternion is the same as applying the matrix."""
|
||||||
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
|
quaternions = random_quaternions(3, torch.float64)
|
||||||
|
quaternions.requires_grad = True
|
||||||
matrices = quaternion_to_matrix(quaternions)
|
matrices = quaternion_to_matrix(quaternions)
|
||||||
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
|
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
|
||||||
transform1 = quaternion_apply(quaternions, points)
|
transform1 = quaternion_apply(quaternions, points)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user