mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-03 12:22:49 +08:00
Update Rotate transform to use device of input rotation
Summary: Currently the Rotate transform does not consider the R's device at all, resulting in errors if you're expecting it to be on cuda but it gets the default casting to cpu. This updates the transform to respect R's device. Reviewed By: nikhilaravi Differential Revision: D27828118 fbshipit-source-id: ddd99f73eadbd990688eb22f3d1ffbacbe168c81
This commit is contained in:
parent
c9dea62162
commit
cd5af2521a
@ -434,7 +434,7 @@ class Transform3d:
|
||||
|
||||
|
||||
class Translate(Transform3d):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
|
||||
"""
|
||||
Create a new Transform3d representing 3D translations.
|
||||
|
||||
@ -448,11 +448,11 @@ class Translate(Transform3d):
|
||||
- A torch scalar
|
||||
- A 1D torch tensor
|
||||
"""
|
||||
super().__init__(device=device)
|
||||
xyz = _handle_input(x, y, z, dtype, device, "Translate")
|
||||
super().__init__(device=xyz.device)
|
||||
N = xyz.shape[0]
|
||||
|
||||
mat = torch.eye(4, dtype=dtype, device=device)
|
||||
mat = torch.eye(4, dtype=dtype, device=self.device)
|
||||
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
|
||||
mat[:, 3, :3] = xyz
|
||||
self._matrix = mat
|
||||
@ -468,7 +468,7 @@ class Translate(Transform3d):
|
||||
|
||||
|
||||
class Scale(Transform3d):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
|
||||
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
|
||||
"""
|
||||
A Transform3d representing a scaling operation, with different scale
|
||||
factors along each coordinate axis.
|
||||
@ -485,12 +485,12 @@ class Scale(Transform3d):
|
||||
- torch scalar
|
||||
- 1D torch tensor
|
||||
"""
|
||||
super().__init__(device=device)
|
||||
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
|
||||
super().__init__(device=xyz.device)
|
||||
N = xyz.shape[0]
|
||||
|
||||
# TODO: Can we do this all in one go somehow?
|
||||
mat = torch.eye(4, dtype=dtype, device=device)
|
||||
mat = torch.eye(4, dtype=dtype, device=self.device)
|
||||
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
|
||||
mat[:, 0, 0] = xyz[:, 0]
|
||||
mat[:, 1, 1] = xyz[:, 1]
|
||||
@ -509,7 +509,7 @@ class Scale(Transform3d):
|
||||
|
||||
class Rotate(Transform3d):
|
||||
def __init__(
|
||||
self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
|
||||
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
|
||||
):
|
||||
"""
|
||||
Create a new Transform3d representing 3D rotation using a rotation
|
||||
@ -520,6 +520,7 @@ class Rotate(Transform3d):
|
||||
orthogonal_tol: tolerance for the test of the orthogonality of R
|
||||
|
||||
"""
|
||||
device = _get_device(R, device)
|
||||
super().__init__(device=device)
|
||||
if R.dim() == 2:
|
||||
R = R[None]
|
||||
@ -548,7 +549,7 @@ class RotateAxisAngle(Rotate):
|
||||
axis: str = "X",
|
||||
degrees: bool = True,
|
||||
dtype=torch.float64,
|
||||
device="cpu",
|
||||
device=None,
|
||||
):
|
||||
"""
|
||||
Create a new Transform3d representing 3D rotation about an axis
|
||||
@ -578,7 +579,7 @@ class RotateAxisAngle(Rotate):
|
||||
# is for transforming column vectors. Therefore we transpose this matrix.
|
||||
# R will always be of shape (N, 3, 3)
|
||||
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
|
||||
super().__init__(device=device, R=R)
|
||||
super().__init__(device=angle.device, R=R)
|
||||
|
||||
|
||||
def _handle_coord(c, dtype, device):
|
||||
@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device):
|
||||
c = torch.tensor(c, dtype=dtype, device=device)
|
||||
if c.dim() == 0:
|
||||
c = c.view(1)
|
||||
if c.device != device:
|
||||
c = c.to(device=device)
|
||||
return c
|
||||
|
||||
|
||||
def _get_device(x, device=None):
|
||||
if device is not None:
|
||||
# User overriding device, leave
|
||||
device = device
|
||||
elif torch.is_tensor(x):
|
||||
# Set device based on input tensor
|
||||
device = x.device
|
||||
else:
|
||||
# Default device is cpu
|
||||
device = "cpu"
|
||||
return device
|
||||
|
||||
|
||||
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
|
||||
"""
|
||||
Helper function to handle parsing logic for building transforms. The output
|
||||
@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
||||
Returns:
|
||||
xyz: Tensor of shape (N, 3)
|
||||
"""
|
||||
device = _get_device(x, device)
|
||||
# If x is actually a tensor of shape (N, 3) then just return it
|
||||
if torch.is_tensor(x) and x.dim() == 2:
|
||||
if x.shape[1] != 3:
|
||||
@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
|
||||
if y is not None or z is not None:
|
||||
msg = "Expected y and z to be None (in %s)" % name
|
||||
raise ValueError(msg)
|
||||
return x
|
||||
return x.to(device=device)
|
||||
|
||||
if allow_singleton and y is None and z is None:
|
||||
y = x
|
||||
@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str):
|
||||
- Python scalar
|
||||
- Torch scalar
|
||||
"""
|
||||
device = _get_device(x, device)
|
||||
if torch.is_tensor(x) and x.dim() > 1:
|
||||
msg = "Expected tensor of shape (N,); got %r (in %s)"
|
||||
raise ValueError(msg % (x.shape, name))
|
||||
|
Loading…
x
Reference in New Issue
Block a user