mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 23:00:34 +08:00
Address black + isort fbsource linter warnings
Summary: Address black + isort fbsource linter warnings from D20558374 (previous diff) Reviewed By: nikhilaravi Differential Revision: D20558373 fbshipit-source-id: d3607de4a01fb24c0d5269634563a7914bddf1c8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
eb512ffde3
commit
d57daa6f85
@@ -22,4 +22,5 @@ from .so3 import (
|
||||
)
|
||||
from .transform3d import Rotate, RotateAxisAngle, Scale, Transform3d, Translate
|
||||
|
||||
|
||||
__all__ = [k for k in globals().keys() if not k.startswith("_")]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import functools
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -155,9 +156,7 @@ def euler_angles_to_matrix(euler_angles, convention: str):
|
||||
for letter in convention:
|
||||
if letter not in ("X", "Y", "Z"):
|
||||
raise ValueError(f"Invalid letter {letter} in convention string.")
|
||||
matrices = map(
|
||||
_axis_angle_rotation, convention, torch.unbind(euler_angles, -1)
|
||||
)
|
||||
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
||||
return functools.reduce(torch.matmul, matrices)
|
||||
|
||||
|
||||
@@ -246,10 +245,7 @@ def matrix_to_euler_angles(matrix, convention: str):
|
||||
|
||||
|
||||
def random_quaternions(
|
||||
n: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device=None,
|
||||
requires_grad=False,
|
||||
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
||||
):
|
||||
"""
|
||||
Generate random quaternions representing rotations,
|
||||
@@ -266,19 +262,14 @@ def random_quaternions(
|
||||
Returns:
|
||||
Quaternions as tensor of shape (N, 4).
|
||||
"""
|
||||
o = torch.randn(
|
||||
(n, 4), dtype=dtype, device=device, requires_grad=requires_grad
|
||||
)
|
||||
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
||||
s = (o * o).sum(1)
|
||||
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
||||
return o
|
||||
|
||||
|
||||
def random_rotations(
|
||||
n: int,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device=None,
|
||||
requires_grad=False,
|
||||
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
||||
):
|
||||
"""
|
||||
Generate random rotations as 3x3 rotation matrices.
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
|
||||
|
||||
|
||||
@@ -65,9 +66,7 @@ def so3_rotation_angle(R, eps: float = 1e-4, cos_angle: bool = False):
|
||||
rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
|
||||
|
||||
if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
|
||||
raise ValueError(
|
||||
"A matrix has trace outside valid range [-1-eps,3+eps]."
|
||||
)
|
||||
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
|
||||
|
||||
# clamp to valid range
|
||||
rot_trace = torch.clamp(rot_trace, -1.0, 3.0)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from .rotation_conversions import _axis_angle_rotation
|
||||
@@ -230,9 +231,7 @@ class Transform3d:
|
||||
# the transformations with get_matrix(), this correctly
|
||||
# right-multiplies by the inverse of self._matrix
|
||||
# at the end of the composition.
|
||||
tinv._transforms = [
|
||||
t.inverse() for t in reversed(self._transforms)
|
||||
]
|
||||
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
|
||||
last = Transform3d(device=self.device)
|
||||
last._matrix = i_matrix
|
||||
tinv._transforms.append(last)
|
||||
@@ -334,9 +333,7 @@ class Transform3d:
|
||||
return self.compose(Scale(device=self.device, *args, **kwargs))
|
||||
|
||||
def rotate_axis_angle(self, *args, **kwargs):
|
||||
return self.compose(
|
||||
RotateAxisAngle(device=self.device, *args, **kwargs)
|
||||
)
|
||||
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
|
||||
|
||||
def clone(self):
|
||||
"""
|
||||
@@ -388,9 +385,7 @@ class Transform3d:
|
||||
|
||||
|
||||
class Translate(Transform3d):
|
||||
def __init__(
|
||||
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
|
||||
):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
|
||||
"""
|
||||
Create a new Transform3d representing 3D translations.
|
||||
|
||||
@@ -424,9 +419,7 @@ class Translate(Transform3d):
|
||||
|
||||
|
||||
class Scale(Transform3d):
|
||||
def __init__(
|
||||
self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"
|
||||
):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
|
||||
"""
|
||||
A Transform3d representing a scaling operation, with different scale
|
||||
factors along each coordinate axis.
|
||||
@@ -444,9 +437,7 @@ class Scale(Transform3d):
|
||||
- 1D torch tensor
|
||||
"""
|
||||
super().__init__(device=device)
|
||||
xyz = _handle_input(
|
||||
x, y, z, dtype, device, "scale", allow_singleton=True
|
||||
)
|
||||
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
|
||||
N = xyz.shape[0]
|
||||
|
||||
# TODO: Can we do this all in one go somehow?
|
||||
@@ -469,11 +460,7 @@ class Scale(Transform3d):
|
||||
|
||||
class Rotate(Transform3d):
|
||||
def __init__(
|
||||
self,
|
||||
R,
|
||||
dtype=torch.float32,
|
||||
device: str = "cpu",
|
||||
orthogonal_tol: float = 1e-5,
|
||||
self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
|
||||
):
|
||||
"""
|
||||
Create a new Transform3d representing 3D rotation using a rotation
|
||||
@@ -562,9 +549,7 @@ def _handle_coord(c, dtype, device):
|
||||
return c
|
||||
|
||||
|
||||
def _handle_input(
|
||||
x, y, z, dtype, device, name: str, allow_singleton: bool = False
|
||||
):
|
||||
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
|
||||
"""
|
||||
Helper function to handle parsing logic for building transforms. The output
|
||||
is always a tensor of shape (N, 3), but there are several types of allowed
|
||||
|
||||
Reference in New Issue
Block a user