diff --git a/pytorch3d/transforms/transform3d.py b/pytorch3d/transforms/transform3d.py index 03ae253a..8cbb3e9d 100644 --- a/pytorch3d/transforms/transform3d.py +++ b/pytorch3d/transforms/transform3d.py @@ -421,8 +421,9 @@ class Transform3d: if self.device != device: other.device = device other._matrix = self._matrix.to(device=device, dtype=dtype) - for t in other._transforms: - t.to(device, copy=copy, dtype=dtype) + other._transforms = [ + t.to(device, copy=copy, dtype=dtype) for t in other._transforms + ] return other def cpu(self): diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 62404d7a..6c2642d8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -20,13 +20,17 @@ class TestTransform(TestCaseMixin, unittest.TestCase): def test_to(self): tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]])) R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]) + cpu_points = torch.rand(9, 3) + cuda_points = cpu_points.cuda() R = Rotate(R) t = Transform3d().compose(R, tr) for _ in range(3): - t.cpu() - t.cuda() - t.cuda() - t.cpu() + t = t.cpu() + t.transform_points(cpu_points) + t = t.cuda() + t.transform_points(cuda_points) + t = t.cuda() + t = t.cpu() def test_clone(self): """