Address black + isort fbsource linter warnings

Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff)

Reviewed By: nikhilaravi

Differential Revision: D20558373

fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
Patrick Labatut
2020-03-29 14:46:33 -07:00
committed by Facebook GitHub Bot
parent eb512ffde3
commit d57daa6f85
110 changed files with 705 additions and 1850 deletions

View File

@@ -22,4 +22,5 @@ from .so3 import (
)
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
__all__ = [k for k in globals().keys() if not k.startswith("_")]

View File

@@ -2,6 +2,7 @@
import functools
from typing import Optional
import torch
@@ -155,9 +156,7 @@ def euler_angles_to_matrix(euler_angles, convention: str):
for letter in convention:
if letter not in ("X", "Y", "Z"):
raise ValueError(f"Invalid letter {letter} in convention string.")
matrices = map(
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
)
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
return functools.reduce(torch.matmul, matrices)
@@ -246,10 +245,7 @@ def matrix_to_euler_angles(matrix, convention: str):
def random_quaternions(
n: int,
dtype: Optional[torch.dtype] = None,
device=None,
requires_grad=False,
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random quaternions representing rotations,
@@ -266,19 +262,14 @@ def random_quaternions(
Returns:
Quaternions as tensor of shape (N, 4).
"""
o = torch.randn(
(n, 4), dtype=dtype, device=device, requires_grad=requires_grad
)
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
s = (o * o).sum(1)
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
return o
def random_rotations(
n: int,
dtype: Optional[torch.dtype] = None,
device=None,
requires_grad=False,
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
):
"""
Generate random rotations as 3x3 rotation matrices.

View File

@@ -3,6 +3,7 @@
import torch
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
@@ -65,9 +66,7 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
raise ValueError(
"A matrix has trace outside valid range [-1-eps,3+eps]."
)
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
# clamp to valid range
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)

View File

@@ -3,6 +3,7 @@
import math
import warnings
from typing import Optional
import torch
from .rotation_conversions import _axis_angle_rotation
@@ -230,9 +231,7 @@ class Transform3d:
# the transformations with get_matrix(), this correctly
# right-multiplies by the inverse of self._matrix
# at the end of the composition.
tinv._transforms = [
t.inverse() for t in reversed(self._transforms)
]
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
last = Transform3d(device=self.device)
last._matrix = i_matrix
tinv._transforms.append(last)
@@ -334,9 +333,7 @@ class Transform3d:
return self.compose(Scale(device=self.device, *args, **kwargs))
def rotate_axis_angle(self, *args, **kwargs):
return self.compose(
RotateAxisAngle(device=self.device, *args, **kwargs)
)
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
def clone(self):
"""
@@ -388,9 +385,7 @@ class Transform3d:
class Translate(Transform3d):
def __init__(
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
"""
Create a new Transform3d representing 3D translations.
@@ -424,9 +419,7 @@ class Translate(Transform3d):
class Scale(Transform3d):
def __init__(
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@@ -444,9 +437,7 @@ class Scale(Transform3d):
- 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(
x, y, z, dtype, device, "scale", allow_singleton=True
)
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
N = xyz.shape[0]
# TODO: Can we do this all in one go somehow?
@@ -469,11 +460,7 @@ class Scale(Transform3d):
class Rotate(Transform3d):
def __init__(
self,
R,
dtype=torch.float32,
device: str = "cpu",
orthogonal_tol: float = 1e-5,
self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
):
"""
Create a new Transform3d representing 3D rotation using a rotation
@@ -562,9 +549,7 @@ def _handle_coord(c, dtype, device):
return c
def _handle_input(
x, y, z, dtype, device, name: str, allow_singleton: bool = False
):
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed