devices for transform3d

Summary: Make `to` on Transform3D carry its member _transforms.

Reviewed By: nikhilaravi

Differential Revision: D25978611

fbshipit-source-id: 12b39e7a657f28d59ca60800bf9f4193a2c08197
This commit is contained in:
Jeremy Reizenstein 2021-01-21 04:56:56 -08:00 committed by Facebook GitHub Bot
parent 4711665edb
commit d60c52df4a
2 changed files with 11 additions and 6 deletions

View File

@ -421,8 +421,9 @@ class Transform3d:
if self.device != device:
other.device = device
other._matrix = self._matrix.to(device=device, dtype=dtype)
for t in other._transforms:
t.to(device, copy=copy, dtype=dtype)
other._transforms = [
t.to(device, copy=copy, dtype=dtype) for t in other._transforms
]
return other
def cpu(self):

View File

@ -20,13 +20,17 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
def test_to(self):
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
cpu_points = torch.rand(9, 3)
cuda_points = cpu_points.cuda()
R = Rotate(R)
t = Transform3d().compose(R, tr)
for _ in range(3):
t.cpu()
t.cuda()
t.cuda()
t.cpu()
t = t.cpu()
t.transform_points(cpu_points)
t = t.cuda()
t.transform_points(cuda_points)
t = t.cuda()
t = t.cpu()
def test_clone(self):
"""