mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-19 22:00:35 +08:00
Tidy uses of torch.device in Transform3d
Summary: Tidy uses of `torch.device` in `Transforms3d`: - Allow `str` or `torch.device` in user-facing methods - Consistently use `torch.device` for internal types - Fix comparison of devices Reviewed By: nikhilaravi Differential Revision: D28929486 fbshipit-source-id: bd1d6cc7ede3d8fd549fd3224a9b07eec53f8164
This commit is contained in:
committed by
Facebook GitHub Bot
parent
48faf8eb7e
commit
13a0110b69
@@ -20,10 +20,35 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
||||
def test_to(self):
|
||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||
cpu_points = torch.rand(9, 3)
|
||||
cuda_points = cpu_points.cuda()
|
||||
R = Rotate(R)
|
||||
t = Transform3d().compose(R, tr)
|
||||
|
||||
cpu_device = torch.device("cpu")
|
||||
|
||||
cpu_t = t.to("cpu")
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cpu_t = t.to(cpu_device)
|
||||
self.assertEqual(cpu_device, cpu_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertIs(t, cpu_t)
|
||||
|
||||
cuda_device = torch.device("cuda")
|
||||
|
||||
cuda_t = t.to("cuda")
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cuda_t = t.to(cuda_device)
|
||||
self.assertEqual(cuda_device, cuda_t.device)
|
||||
self.assertEqual(cpu_device, t.device)
|
||||
self.assertIsNot(t, cuda_t)
|
||||
|
||||
cpu_points = torch.rand(9, 3)
|
||||
cuda_points = cpu_points.cuda()
|
||||
for _ in range(3):
|
||||
t = t.cpu()
|
||||
t.transform_points(cpu_points)
|
||||
|
||||
Reference in New Issue
Block a user