Make Transform3d.to() not ignore dtype

Summary: Make Transform3d.to() not ignore a different dtype when device is the same and no copy is requested. Fix other methods where dtype is ignored.

Reviewed By: nikhilaravi

Differential Revision: D28981171

fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
This commit is contained in:
Patrick Labatut
2021-06-09 15:48:56 -07:00
committed by Facebook GitHub Bot
parent 626bf3fe23
commit 44508ed0db
2 changed files with 39 additions and 10 deletions

View File

@@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
cpu_t = t.to("cpu")
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIs(t, cpu_t)
cpu_t = t.to(cpu_device)
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIs(t, cpu_t)
cpu_t = t.to(dtype=torch.float64, device=cpu_device)
self.assertEqual(cpu_device, cpu_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float64, cpu_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cpu_t)
cuda_device = torch.device("cuda")
cuda_t = t.to("cuda")
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)
cuda_t = t.to(cuda_device)
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float32, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)
cuda_t = t.to(dtype=torch.float64, device=cuda_device)
self.assertEqual(cuda_device, cuda_t.device)
self.assertEqual(cpu_device, t.device)
self.assertEqual(torch.float64, cuda_t.dtype)
self.assertEqual(torch.float32, t.dtype)
self.assertIsNot(t, cuda_t)
cpu_points = torch.rand(9, 3)