From cd5af2521a56bf3f8feca2926bbaa55cae3b3604 Mon Sep 17 00:00:00 2001 From: Dave Schnizlein Date: Wed, 19 May 2021 10:02:47 -0700 Subject: [PATCH] Update Rotate transform to use device of input rotation Summary: Currently the Rotate transform does not consider the R's device at all, resulting in errors if you're expecting it to be on cuda but it gets the default casting to cpu. This updates the transform to respect R's device. Reviewed By: nikhilaravi Differential Revision: D27828118 fbshipit-source-id: ddd99f73eadbd990688eb22f3d1ffbacbe168c81 --- pytorch3d/transforms/transform3d.py | 38 +++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index c954e749..23a8aa26 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -434,7 +434,7 @@ class Transform3d: class Translate(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None): """ Create a new Transform3d representing 3D translations. @@ -448,11 +448,11 @@ class Translate(Transform3d): - A torch scalar - A 1D torch tensor """ - super().__init__(device=device) xyz = _handle_input(x, y, z, dtype, device, "Translate") + super().__init__(device=xyz.device) N = xyz.shape[0] - mat = torch.eye(4, dtype=dtype, device=device) + mat = torch.eye(4, dtype=dtype, device=self.device) mat = mat.view(1, 4, 4).repeat(N, 1, 1) mat[:, 3, :3] = xyz self._matrix = mat @@ -468,7 +468,7 @@ class Translate(Transform3d): class Scale(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None): """ A Transform3d representing a scaling operation, with different scale factors along each coordinate axis. @@ -485,12 +485,12 @@ class Scale(Transform3d): - torch scalar - 1D torch tensor """ - super().__init__(device=device) xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True) + super().__init__(device=xyz.device) N = xyz.shape[0] # TODO: Can we do this all in one go somehow? - mat = torch.eye(4, dtype=dtype, device=device) + mat = torch.eye(4, dtype=dtype, device=self.device) mat = mat.view(1, 4, 4).repeat(N, 1, 1) mat[:, 0, 0] = xyz[:, 0] mat[:, 1, 1] = xyz[:, 1] @@ -509,7 +509,7 @@ class Scale(Transform3d): class Rotate(Transform3d): def __init__( - self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5 + self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5 ): """ Create a new Transform3d representing 3D rotation using a rotation @@ -520,6 +520,7 @@ class Rotate(Transform3d): orthogonal_tol: tolerance for the test of the orthogonality of R """ + device = _get_device(R, device) super().__init__(device=device) if R.dim() == 2: R = R[None] @@ -548,7 +549,7 @@ class RotateAxisAngle(Rotate): axis: str = "X", degrees: bool = True, dtype=torch.float64, - device="cpu", + device=None, ): """ Create a new Transform3d representing 3D rotation about an axis @@ -578,7 +579,7 @@ class RotateAxisAngle(Rotate): # is for transforming column vectors. Therefore we transpose this matrix. # R will always be of shape (N, 3, 3) R = _axis_angle_rotation(axis, angle).transpose(1, 2) - super().__init__(device=device, R=R) + super().__init__(device=angle.device, R=R) def _handle_coord(c, dtype, device): @@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device): c = torch.tensor(c, dtype=dtype, device=device) if c.dim() == 0: c = c.view(1) + if c.device != device: + c = c.to(device=device) return c +def _get_device(x, device=None): + if device is not None: + # User overriding device, leave + device = device + elif torch.is_tensor(x): + # Set device based on input tensor + device = x.device + else: + # Default device is cpu + device = "cpu" + return device + + def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False): """ Helper function to handle parsing logic for building transforms. The output @@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal Returns: xyz: Tensor of shape (N, 3) """ + device = _get_device(x, device) # If x is actually a tensor of shape (N, 3) then just return it if torch.is_tensor(x) and x.dim() == 2: if x.shape[1] != 3: @@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal if y is not None or z is not None: msg = "Expected y and z to be None (in %s)" % name raise ValueError(msg) - return x + return x.to(device=device) if allow_singleton and y is None and z is None: y = x @@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str): - Python scalar - Torch scalar """ + device = _get_device(x, device) if torch.is_tensor(x) and x.dim() > 1: msg = "Expected tensor of shape (N,); got %r (in %s)" raise ValueError(msg % (x.shape, name))