From 44508ed0db269ed0b7c952fbee6bd09105a1c653 Mon Sep 17 00:00:00 2001 From: Patrick Labatut Date: Wed, 9 Jun 2021 15:48:56 -0700 Subject: [PATCH] Make Transform3d.to() not ignore dtype Summary: Make Transform3d.to() not ignore a different dtype when device is the same and no copy is requested. Fix other methods where dtype is ignored. Reviewed By: nikhilaravi Differential Revision: D28981171 fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf --- pytorch3d/transforms/transform3d.py | 27 +++++++++++++++++---------- tests/test_transforms.py | 22 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index da841a8e..9273d416 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -162,13 +162,15 @@ class Transform3d: raise ValueError( '"matrix" has to be a tensor of shape (minibatch, 4, 4)' ) - # set the device from matrix + # set dtype and device from matrix + dtype = matrix.dtype device = matrix.device self._matrix = matrix.view(-1, 4, 4) self._transforms = [] # store transforms to compose self._lu = None self.device = make_device(device) + self.dtype = dtype def __len__(self): return self.get_matrix().shape[0] @@ -200,7 +202,7 @@ class Transform3d: Returns: A new Transform3d with the stored transforms """ - out = Transform3d(device=self.device) + out = Transform3d(dtype=self.dtype, device=self.device) out._matrix = self._matrix.clone() for other in others: if not isinstance(other, Transform3d): @@ -259,7 +261,7 @@ class Transform3d: transformation. """ - tinv = Transform3d(device=self.device) + tinv = Transform3d(dtype=self.dtype, device=self.device) if invert_composed: # first compose then invert @@ -278,7 +280,7 @@ class Transform3d: # right-multiplies by the inverse of self._matrix # at the end of the composition. tinv._transforms = [t.inverse() for t in reversed(self._transforms)] - last = Transform3d(device=self.device) + last = Transform3d(dtype=self.dtype, device=self.device) last._matrix = i_matrix tinv._transforms.append(last) else: @@ -291,7 +293,7 @@ class Transform3d: def stack(self, *others): transforms = [self] + list(others) matrix = torch.cat([t._matrix for t in transforms], dim=0) - out = Transform3d() + out = Transform3d(dtype=self.dtype, device=self.device) out._matrix = matrix return out @@ -392,7 +394,7 @@ class Transform3d: Returns: new Transforms object. """ - other = Transform3d(device=self.device) + other = Transform3d(dtype=self.dtype, device=self.device) if self._lu is not None: other._lu = [elem.clone() for elem in self._lu] other._matrix = self._matrix.clone() @@ -422,17 +424,22 @@ class Transform3d: Transform3d object. """ device_ = make_device(device) - if not copy and self.device == device_: + dtype_ = self.dtype if dtype is None else dtype + skip_to = self.device == device_ and self.dtype == dtype_ + + if not copy and skip_to: return self other = self.clone() - if self.device == device_: + + if skip_to: return other other.device = device_ - other._matrix = self._matrix.to(device=device_, dtype=dtype) + other.dtype = dtype_ + other._matrix = other._matrix.to(device=device_, dtype=dtype_) other._transforms = [ - t.to(device_, copy=copy, dtype=dtype) for t in other._transforms + t.to(device_, copy=copy, dtype=dtype_) for t in other._transforms ] return other diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 26e38ec1..490a3201 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase): cpu_t = t.to("cpu") self.assertEqual(cpu_device, cpu_t.device) self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float32, cpu_t.dtype) + self.assertEqual(torch.float32, t.dtype) self.assertIs(t, cpu_t) cpu_t = t.to(cpu_device) self.assertEqual(cpu_device, cpu_t.device) self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float32, cpu_t.dtype) + self.assertEqual(torch.float32, t.dtype) self.assertIs(t, cpu_t) + cpu_t = t.to(dtype=torch.float64, device=cpu_device) + self.assertEqual(cpu_device, cpu_t.device) + self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float64, cpu_t.dtype) + self.assertEqual(torch.float32, t.dtype) + self.assertIsNot(t, cpu_t) + cuda_device = torch.device("cuda") cuda_t = t.to("cuda") self.assertEqual(cuda_device, cuda_t.device) self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float32, cuda_t.dtype) + self.assertEqual(torch.float32, t.dtype) self.assertIsNot(t, cuda_t) cuda_t = t.to(cuda_device) self.assertEqual(cuda_device, cuda_t.device) self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float32, cuda_t.dtype) + self.assertEqual(torch.float32, t.dtype) + self.assertIsNot(t, cuda_t) + + cuda_t = t.to(dtype=torch.float64, device=cuda_device) + self.assertEqual(cuda_device, cuda_t.device) + self.assertEqual(cpu_device, t.device) + self.assertEqual(torch.float64, cuda_t.dtype) + self.assertEqual(torch.float32, t.dtype) self.assertIsNot(t, cuda_t) cpu_points = torch.rand(9, 3)