Update Rotate transform to use device of input rotation

Summary: Currently the Rotate transform does not consider the R's device at all, resulting in errors if you're expecting it to be on cuda but it gets the default casting to cpu. This updates the transform to respect R's device.

Reviewed By: nikhilaravi

Differential Revision: D27828118

fbshipit-source-id: ddd99f73eadbd990688eb22f3d1ffbacbe168c81
This commit is contained in:
Dave Schnizlein 2021-05-19 10:02:47 -07:00 committed by Facebook GitHub Bot
parent c9dea62162
commit cd5af2521a

View File

@ -434,7 +434,7 @@ class Transform3d:
class Translate(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
"""
Create a new Transform3d representing 3D translations.
@ -448,11 +448,11 @@ class Translate(Transform3d):
- A torch scalar
- A 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(x, y, z, dtype, device, "Translate")
super().__init__(device=xyz.device)
N = xyz.shape[0]
mat = torch.eye(4, dtype=dtype, device=device)
mat = torch.eye(4, dtype=dtype, device=self.device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 3, :3] = xyz
self._matrix = mat
@ -468,7 +468,7 @@ class Translate(Transform3d):
class Scale(Transform3d):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"):
def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@ -485,12 +485,12 @@ class Scale(Transform3d):
- torch scalar
- 1D torch tensor
"""
super().__init__(device=device)
xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True)
super().__init__(device=xyz.device)
N = xyz.shape[0]
# TODO: Can we do this all in one go somehow?
mat = torch.eye(4, dtype=dtype, device=device)
mat = torch.eye(4, dtype=dtype, device=self.device)
mat = mat.view(1, 4, 4).repeat(N, 1, 1)
mat[:, 0, 0] = xyz[:, 0]
mat[:, 1, 1] = xyz[:, 1]
@ -509,7 +509,7 @@ class Scale(Transform3d):
class Rotate(Transform3d):
def __init__(
self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5
self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
):
"""
Create a new Transform3d representing 3D rotation using a rotation
@ -520,6 +520,7 @@ class Rotate(Transform3d):
orthogonal_tol: tolerance for the test of the orthogonality of R
"""
device = _get_device(R, device)
super().__init__(device=device)
if R.dim() == 2:
R = R[None]
@ -548,7 +549,7 @@ class RotateAxisAngle(Rotate):
axis: str = "X",
degrees: bool = True,
dtype=torch.float64,
device="cpu",
device=None,
):
"""
Create a new Transform3d representing 3D rotation about an axis
@ -578,7 +579,7 @@ class RotateAxisAngle(Rotate):
# is for transforming column vectors. Therefore we transpose this matrix.
# R will always be of shape (N, 3, 3)
R = _axis_angle_rotation(axis, angle).transpose(1, 2)
super().__init__(device=device, R=R)
super().__init__(device=angle.device, R=R)
def _handle_coord(c, dtype, device):
@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device):
c = torch.tensor(c, dtype=dtype, device=device)
if c.dim() == 0:
c = c.view(1)
if c.device != device:
c = c.to(device=device)
return c
def _get_device(x, device=None):
if device is not None:
# User overriding device, leave
device = device
elif torch.is_tensor(x):
# Set device based on input tensor
device = x.device
else:
# Default device is cpu
device = "cpu"
return device
def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
"""
Helper function to handle parsing logic for building transforms. The output
@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
Returns:
xyz: Tensor of shape (N, 3)
"""
device = _get_device(x, device)
# If x is actually a tensor of shape (N, 3) then just return it
if torch.is_tensor(x) and x.dim() == 2:
if x.shape[1] != 3:
@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
if y is not None or z is not None:
msg = "Expected y and z to be None (in %s)" % name
raise ValueError(msg)
return x
return x.to(device=device)
if allow_singleton and y is None and z is None:
y = x
@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str):
- Python scalar
- Torch scalar
"""
device = _get_device(x, device)
if torch.is_tensor(x) and x.dim() > 1:
msg = "Expected tensor of shape (N,); got %r (in %s)"
raise ValueError(msg % (x.shape, name))