diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index c954e749..23a8aa26 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -434,7 +434,7 @@ class Transform3d: class Translate(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None): """ Create a new Transform3d representing 3D translations. @@ -448,11 +448,11 @@ class Translate(Transform3d): - A torch scalar - A 1D torch tensor """ - super().__init__(device=device) xyz = _handle_input(x, y, z, dtype, device, "Translate") + super().__init__(device=xyz.device) N = xyz.shape[0] - mat = torch.eye(4, dtype=dtype, device=device) + mat = torch.eye(4, dtype=dtype, device=self.device) mat = mat.view(1, 4, 4).repeat(N, 1, 1) mat[:, 3, :3] = xyz self._matrix = mat @@ -468,7 +468,7 @@ class Translate(Transform3d): class Scale(Transform3d): - def __init__(self, x, y=None, z=None, dtype=torch.float32, device="cpu"): + def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None): """ A Transform3d representing a scaling operation, with different scale factors along each coordinate axis. @@ -485,12 +485,12 @@ class Scale(Transform3d): - torch scalar - 1D torch tensor """ - super().__init__(device=device) xyz = _handle_input(x, y, z, dtype, device, "scale", allow_singleton=True) + super().__init__(device=xyz.device) N = xyz.shape[0] # TODO: Can we do this all in one go somehow? - mat = torch.eye(4, dtype=dtype, device=device) + mat = torch.eye(4, dtype=dtype, device=self.device) mat = mat.view(1, 4, 4).repeat(N, 1, 1) mat[:, 0, 0] = xyz[:, 0] mat[:, 1, 1] = xyz[:, 1] @@ -509,7 +509,7 @@ class Scale(Transform3d): class Rotate(Transform3d): def __init__( - self, R, dtype=torch.float32, device="cpu", orthogonal_tol: float = 1e-5 + self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5 ): """ Create a new Transform3d representing 3D rotation using a rotation @@ -520,6 +520,7 @@ class Rotate(Transform3d): orthogonal_tol: tolerance for the test of the orthogonality of R """ + device = _get_device(R, device) super().__init__(device=device) if R.dim() == 2: R = R[None] @@ -548,7 +549,7 @@ class RotateAxisAngle(Rotate): axis: str = "X", degrees: bool = True, dtype=torch.float64, - device="cpu", + device=None, ): """ Create a new Transform3d representing 3D rotation about an axis @@ -578,7 +579,7 @@ class RotateAxisAngle(Rotate): # is for transforming column vectors. Therefore we transpose this matrix. # R will always be of shape (N, 3, 3) R = _axis_angle_rotation(axis, angle).transpose(1, 2) - super().__init__(device=device, R=R) + super().__init__(device=angle.device, R=R) def _handle_coord(c, dtype, device): @@ -595,9 +596,24 @@ def _handle_coord(c, dtype, device): c = torch.tensor(c, dtype=dtype, device=device) if c.dim() == 0: c = c.view(1) + if c.device != device: + c = c.to(device=device) return c +def _get_device(x, device=None): + if device is not None: + # User overriding device, leave + device = device + elif torch.is_tensor(x): + # Set device based on input tensor + device = x.device + else: + # Default device is cpu + device = "cpu" + return device + + def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False): """ Helper function to handle parsing logic for building transforms. The output @@ -626,6 +642,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal Returns: xyz: Tensor of shape (N, 3) """ + device = _get_device(x, device) # If x is actually a tensor of shape (N, 3) then just return it if torch.is_tensor(x) and x.dim() == 2: if x.shape[1] != 3: @@ -634,7 +651,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal if y is not None or z is not None: msg = "Expected y and z to be None (in %s)" % name raise ValueError(msg) - return x + return x.to(device=device) if allow_singleton and y is None and z is None: y = x @@ -665,6 +682,7 @@ def _handle_angle_input(x, dtype, device, name: str): - Python scalar - Torch scalar """ + device = _get_device(x, device) if torch.is_tensor(x) and x.dim() > 1: msg = "Expected tensor of shape (N,); got %r (in %s)" raise ValueError(msg % (x.shape, name))