diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py
index 23a8aa26..da841a8e 100644
--- a/pytorch3d/transforms/transform3d.py
+++ b/pytorch3d/transforms/transform3d.py
@@ -6,6 +6,7 @@ from typing import List, Optional, Union
 
 import torch
 
+from ..common.types import Device, get_device, make_device
 from .rotation_conversions import _axis_angle_rotation
 
 
@@ -137,7 +138,7 @@ class Transform3d:
     def __init__(
         self,
         dtype: torch.dtype = torch.float32,
-        device="cpu",
+        device: Device = "cpu",
         matrix: Optional[torch.Tensor] = None,
     ):
         """
@@ -167,7 +168,7 @@ class Transform3d:
 
         self._transforms = []  # store transforms to compose
        self._lu = None
-        self.device = device
+        self.device = make_device(device)
 
     def __len__(self):
         return self.get_matrix().shape[0]
@@ -398,7 +399,12 @@ class Transform3d:
         other._transforms = [t.clone() for t in self._transforms]
         return other
 
-    def to(self, device, copy: bool = False, dtype=None):
+    def to(
+        self,
+        device: Device,
+        copy: bool = False,
+        dtype: Optional[torch.dtype] = None,
+    ):
         """
         Match functionality of torch.Tensor.to()
         If copy = True or the self Tensor is on a different device, the
@@ -407,7 +413,7 @@ class Transform3d:
         then self is returned.
 
         Args:
          device: Device id for the new tensor.
-          device: Device id for the new tensor.
+          device: Device (as str or torch.device) for the new tensor.
          copy: Boolean indicator whether or not to clone self. Default False.
          dtype: If not None, casts the internal tensor variables
              to a given torch.dtype.
@@ -415,26 +421,37 @@ class Transform3d:
        Returns:
          Transform3d object.
        """
-        if not copy and self.device == device:
+        device_ = make_device(device)
+        if not copy and self.device == device_:
            return self
+
        other = self.clone()
-        if self.device != device:
-            other.device = device
-            other._matrix = self._matrix.to(device=device, dtype=dtype)
-            other._transforms = [
-                t.to(device, copy=copy, dtype=dtype) for t in other._transforms
-            ]
+        if self.device == device_:
+            return other
+
+        other.device = device_
+        other._matrix = self._matrix.to(device=device_, dtype=dtype)
+        other._transforms = [
+            t.to(device_, copy=copy, dtype=dtype) for t in other._transforms
+        ]
        return other
 
    def cpu(self):
-        return self.to(torch.device("cpu"))
+        return self.to("cpu")
 
    def cuda(self):
-        return self.to(torch.device("cuda"))
+        return self.to("cuda")
 
 
 class Translate(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
        """
        Create a new Transform3d representing 3D translations.
 
@@ -468,7 +485,14 @@ class Translate(Transform3d):
 
 
 class Scale(Transform3d):
-    def __init__(self, x, y=None, z=None, dtype=torch.float32, device=None):
+    def __init__(
+        self,
+        x,
+        y=None,
+        z=None,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+    ):
        """
        A Transform3d representing a scaling operation, with different scale
        factors along each coordinate axis.
@@ -509,7 +533,11 @@ class Scale(Transform3d):
 
 class Rotate(Transform3d):
     def __init__(
-        self, R, dtype=torch.float32, device=None, orthogonal_tol: float = 1e-5
+        self,
+        R: torch.Tensor,
+        dtype: torch.dtype = torch.float32,
+        device: Optional[Device] = None,
+        orthogonal_tol: float = 1e-5,
     ):
         """
         Create a new Transform3d representing 3D rotation using a rotation
@@ -520,17 +548,17 @@ class Rotate(Transform3d):
             orthogonal_tol: tolerance for the test of the orthogonality of R
 
         """
-        device = _get_device(R, device)
-        super().__init__(device=device)
+        device_ = get_device(R, device)
+        super().__init__(device=device_)
         if R.dim() == 2:
             R = R[None]
         if R.shape[-2:] != (3, 3):
             msg = "R must have shape (3, 3) or (N, 3, 3); got %s"
             raise ValueError(msg % repr(R.shape))
-        R = R.to(dtype=dtype).to(device=device)
+        R = R.to(dtype=dtype).to(device=device_)
         _check_valid_rotation_matrix(R, tol=orthogonal_tol)
         N = R.shape[0]
-        mat = torch.eye(4, dtype=dtype, device=device)
+        mat = torch.eye(4, dtype=dtype, device=device_)
         mat = mat.view(1, 4, 4).repeat(N, 1, 1)
         mat[:, :3, :3] = R
         self._matrix = mat
@@ -548,8 +576,8 @@ class RotateAxisAngle(Rotate):
         angle,
         axis: str = "X",
         degrees: bool = True,
-        dtype=torch.float64,
-        device=None,
+        dtype: torch.dtype = torch.float64,
+        device: Optional[Device] = None,
     ):
         """
         Create a new Transform3d representing 3D rotation about an axis
@@ -582,7 +610,7 @@ class RotateAxisAngle(Rotate):
         super().__init__(device=angle.device, R=R)
 
 
-def _handle_coord(c, dtype, device):
+def _handle_coord(c, dtype: torch.dtype, device: torch.device):
     """
     Helper function for _handle_input.
 
@@ -601,20 +629,15 @@ def _handle_coord(c, dtype, device):
     return c
 
 
-def _get_device(x, device=None):
-    if device is not None:
-        # User overriding device, leave
-        device = device
-    elif torch.is_tensor(x):
-        # Set device based on input tensor
-        device = x.device
-    else:
-        # Default device is cpu
-        device = "cpu"
-    return device
-
-
-def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
+def _handle_input(
+    x,
+    y,
+    z,
+    dtype: torch.dtype,
+    device: Optional[Device],
+    name: str,
+    allow_singleton: bool = False,
+):
     """
     Helper function to handle parsing logic for building transforms. The output
     is always a tensor of shape (N, 3), but there are several types of allowed
@@ -642,7 +665,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     Returns:
         xyz: Tensor of shape (N, 3)
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)
     # If x is actually a tensor of shape (N, 3) then just return it
     if torch.is_tensor(x) and x.dim() == 2:
         if x.shape[1] != 3:
@@ -651,14 +674,14 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
         if y is not None or z is not None:
             msg = "Expected y and z to be None (in %s)" % name
             raise ValueError(msg)
-        return x.to(device=device)
+        return x.to(device=device_)
 
     if allow_singleton and y is None and z is None:
         y = x
         z = x
 
     # Convert all to 1D tensors
-    xyz = [_handle_coord(c, dtype, device) for c in [x, y, z]]
+    xyz = [_handle_coord(c, dtype, device_) for c in [x, y, z]]
 
     # Broadcast and concatenate
     sizes = [c.shape[0] for c in xyz]
@@ -672,7 +695,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
     return xyz
 
 
-def _handle_angle_input(x, dtype, device, name: str):
+def _handle_angle_input(x, dtype: torch.dtype, device: Optional[Device], name: str):
     """
     Helper function for building a rotation function using angles.
     The output is always of shape (N,).
@@ -682,12 +705,12 @@ def _handle_angle_input(x, dtype, device, name: str):
         - Python scalar
         - Torch scalar
     """
-    device = _get_device(x, device)
+    device_ = get_device(x, device)
     if torch.is_tensor(x) and x.dim() > 1:
         msg = "Expected tensor of shape (N,); got %r (in %s)"
         raise ValueError(msg % (x.shape, name))
     else:
-        return _handle_coord(x, dtype, device)
+        return _handle_coord(x, dtype, device_)
 
 
 def _broadcast_bmm(a, b):
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 6c2642d8..26e38ec1 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
     def test_to(self):
         tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
         R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
-        cpu_points = torch.rand(9, 3)
-        cuda_points = cpu_points.cuda()
         R = Rotate(R)
         t = Transform3d().compose(R, tr)
+
+        cpu_device = torch.device("cpu")
+
+        cpu_t = t.to("cpu")
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cpu_t = t.to(cpu_device)
+        self.assertEqual(cpu_device, cpu_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIs(t, cpu_t)
+
+        cuda_device = torch.device("cuda")
+
+        cuda_t = t.to("cuda")
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cuda_t = t.to(cuda_device)
+        self.assertEqual(cuda_device, cuda_t.device)
+        self.assertEqual(cpu_device, t.device)
+        self.assertIsNot(t, cuda_t)
+
+        cpu_points = torch.rand(9, 3)
+        cuda_points = cpu_points.cuda()
         for _ in range(3):
             t = t.cpu()
             t.transform_points(cpu_points)
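
Reviewer note: a minimal usage sketch of the device handling introduced above, mirroring what the updated test_to exercises. It assumes a CUDA-capable PyTorch build and is illustrative only, not part of the diff.

import torch

from pytorch3d.transforms import Rotate, Transform3d, Translate

# Device can now be given either as a str or a torch.device; Transform3d
# normalizes it internally via make_device().
tr = Translate(torch.tensor([[1.0, 2.0, 3.0]]))
R = Rotate(torch.eye(3))
t = Transform3d(device="cpu").compose(R, tr)

# to() returns self when no copy is needed ...
assert t.to("cpu") is t
assert t.to(torch.device("cpu")) is t

# ... and an independent, moved copy when the device changes.
t_cuda = t.to("cuda")  # equivalent to t.cuda()
assert t_cuda is not t
assert t.device == torch.device("cpu")  # the original is left untouched

points = torch.rand(9, 3, device="cuda")
print(t_cuda.transform_points(points).shape)  # torch.Size([9, 3])

Normalizing the device through make_device() at the top of to() is what lets it short-circuit with self when nothing needs to move, regardless of whether the caller passed a str or a torch.device.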