mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Make Transform3d.to() not ignore dtype
Summary: Make Transform3d.to() not ignore a different dtype when device is the same and no copy is requested. Fix other methods where dtype is ignored. Reviewed By: nikhilaravi Differential Revision: D28981171 fbshipit-source-id: 4528e6092f4a693aecbe8131ede985fca84e84cf
This commit is contained in:
committed by
Facebook GitHub Bot
parent
626bf3fe23
commit
44508ed0db
@@ -28,23 +28,45 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
cpu_t = t.to("cpu")
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cpu_t = t.to(cpu_device)
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cpu_t = t.to(dtype=torch.float64, device=cpu_device)
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float64, cpu_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cpu_t)
|
||||
|
||||
cuda_device = torch.device("cuda")
|
||||
|
||||
cuda_t = t.to("cuda")
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cuda_t = t.to(cuda_device)
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float32, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cuda_t = t.to(dtype=torch.float64, device=cuda_device)
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertEqual(torch.float64, cuda_t.dtype)
|
||||
self.assertEqual(torch.float32, t.dtype)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cpu_points = torch.rand(9, 3)
|
||||
|
||||
Reference in New Issue
Block a user