devices for transform3d

Summary: Make `to` on Transform3D carry its member _transforms.

Reviewed By: nikhilaravi

Differential Revision: D25978611

fbshipit-source-id: 12b39e7a657f28d59ca60800bf9f4193a2c08197
This commit is contained in:
Jeremy Reizenstein
2021-01-21 04:56:56 -08:00
committed by Facebook GitHub Bot
parent 4711665edb
commit d60c52df4a
2 changed files with 11 additions and 6 deletions

View File

@@ -421,8 +421,9 @@ class Transform3d:
if self.device != device:
other.device = device
other._matrix = self._matrix.to(device=device, dtype=dtype)
for t in other._transforms:
t.to(device, copy=copy, dtype=dtype)
other._transforms = [
t.to(device, copy=copy, dtype=dtype) for t in other._transforms
]
return other
def cpu(self):