mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
devices for transform3d
Summary: Make `to` on Transform3D carry its member _transforms. Reviewed By: nikhilaravi Differential Revision: D25978611 fbshipit-source-id: 12b39e7a657f28d59ca60800bf9f4193a2c08197
This commit is contained in:
parent
4711665edb
commit
d60c52df4a
@ -421,8 +421,9 @@ class Transform3d:
|
|||||||
if self.device != device:
|
if self.device != device:
|
||||||
other.device = device
|
other.device = device
|
||||||
other._matrix = self._matrix.to(device=device, dtype=dtype)
|
other._matrix = self._matrix.to(device=device, dtype=dtype)
|
||||||
for t in other._transforms:
|
other._transforms = [
|
||||||
t.to(device, copy=copy, dtype=dtype)
|
t.to(device, copy=copy, dtype=dtype) for t in other._transforms
|
||||||
|
]
|
||||||
return other
|
return other
|
||||||
|
|
||||||
def cpu(self):
|
def cpu(self):
|
||||||
|
@ -20,13 +20,17 @@ class TestTransform(TestCaseMixin, unittest.TestCase):
|
|||||||
def test_to(self):
|
def test_to(self):
|
||||||
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
||||||
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
||||||
|
cpu_points = torch.rand(9, 3)
|
||||||
|
cuda_points = cpu_points.cuda()
|
||||||
R = Rotate(R)
|
R = Rotate(R)
|
||||||
t = Transform3d().compose(R, tr)
|
t = Transform3d().compose(R, tr)
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
t.cpu()
|
t = t.cpu()
|
||||||
t.cuda()
|
t.transform_points(cpu_points)
|
||||||
t.cuda()
|
t = t.cuda()
|
||||||
t.cpu()
|
t.transform_points(cuda_points)
|
||||||
|
t = t.cuda()
|
||||||
|
t = t.cpu()
|
||||||
|
|
||||||
def test_clone(self):
|
def test_clone(self):
|
||||||
"""
|
"""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user