detach for meshes, pointclouds, textures

Summary: Add `detach` for Meshes, Pointclouds, Textures

Reviewed By: nikhilaravi

Differential Revision: D23070418

fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
Georgia Gkioxari
2020-08-17 14:53:56 -07:00
committed by Facebook GitHub Bot
parent 5852b74d12
commit 7f2f95f225
6 changed files with 283 additions and 8 deletions

View File

@@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
with_normals: bool = True,
with_features: bool = True,
min_points: int = 0,
requires_grad: bool = False,
):
"""
Function to generate a Pointclouds object of N meshes with
@@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
p.fill_(p[0])
points_list = [
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
torch.rand(
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
)
for i in p
]
normals_list, features_list = None, None
if with_normals:
normals_list = [
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
torch.rand(
(i, 3),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
]
if with_features:
features_list = [
torch.rand((i, channels), device=device, dtype=torch.float32) for i in p
torch.rand(
(i, channels),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
]
if lists_to_tensors:
@@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
self.assertCloudsEqual(clouds, new_clouds)
def test_detach(self):
N = 5
for lists_to_tensors in (True, False):
clouds = self.init_cloud(
N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True
)
for force in (False, True):
if force:
clouds.points_packed()
new_clouds = clouds.detach()
for cloud in new_clouds.points_list():
self.assertTrue(cloud.requires_grad == False)
for normal in new_clouds.normals_list():
self.assertTrue(normal.requires_grad == False)
for feats in new_clouds.features_list():
self.assertTrue(feats.requires_grad == False)
for attrib in [
"points_packed",
"normals_packed",
"features_packed",
"points_padded",
"normals_padded",
"features_padded",
]:
self.assertTrue(
getattr(new_clouds, attrib)().requires_grad == False
)
self.assertCloudsEqual(clouds, new_clouds)
def assertCloudsEqual(self, cloud1, cloud2):
N = len(cloud1)
self.assertEqual(N, len(cloud2))