Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
Patrick Labatut
2020-03-29 14:46:33 -07:00
committed by Facebook GitHub Bot
parent eb512ffde3
commit d57daa6f85
110 changed files with 705 additions and 1850 deletions

View File

@@ -2,6 +2,7 @@
import functools
from typing import Optional
import torch
@@ -155,9 +156,7 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
)
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
@@ -246,10 +245,7 @@ def matrix_to_euler_angles(matrix, convention: str):
def random_quaternions(
n: int,
dtype: Optional[torch.dtype] = None,
device=None,
requires_grad=False,
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random quaternions representing rotations,
@@ -266,19 +262,14 @@ def random_quaternions(
Returns:
Quaternions as tensor of shape (N, 4).
"""
o = torch.randn(
(n, 4), dtype=dtype, device=device, requires_grad=requires_grad
)
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o
def random_rotations(
n: int,
dtype: Optional[torch.dtype] = None,
device=None,
requires_grad=False,
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random rotations as 3x3 rotation matrices.