mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
detach for meshes, pointclouds, textures
Summary: Add `detach` for Meshes, Pointclouds, Textures Reviewed By: nikhilaravi Differential Revision: D23070418 fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5852b74d12
commit
7f2f95f225
@@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
with_normals: bool = True,
|
||||
with_features: bool = True,
|
||||
min_points: int = 0,
|
||||
requires_grad: bool = False,
|
||||
):
|
||||
"""
|
||||
Function to generate a Pointclouds object of N meshes with
|
||||
@@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
p.fill_(p[0])
|
||||
|
||||
points_list = [
|
||||
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
|
||||
torch.rand(
|
||||
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
|
||||
)
|
||||
for i in p
|
||||
]
|
||||
normals_list, features_list = None, None
|
||||
if with_normals:
|
||||
normals_list = [
|
||||
torch.rand((i, 3), device=device, dtype=torch.float32) for i in p
|
||||
torch.rand(
|
||||
(i, 3),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
for i in p
|
||||
]
|
||||
if with_features:
|
||||
features_list = [
|
||||
torch.rand((i, channels), device=device, dtype=torch.float32) for i in p
|
||||
torch.rand(
|
||||
(i, channels),
|
||||
device=device,
|
||||
dtype=torch.float32,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
for i in p
|
||||
]
|
||||
|
||||
if lists_to_tensors:
|
||||
@@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
self.assertCloudsEqual(clouds, new_clouds)
|
||||
|
||||
def test_detach(self):
|
||||
N = 5
|
||||
for lists_to_tensors in (True, False):
|
||||
clouds = self.init_cloud(
|
||||
N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True
|
||||
)
|
||||
for force in (False, True):
|
||||
if force:
|
||||
clouds.points_packed()
|
||||
|
||||
new_clouds = clouds.detach()
|
||||
|
||||
for cloud in new_clouds.points_list():
|
||||
self.assertTrue(cloud.requires_grad == False)
|
||||
for normal in new_clouds.normals_list():
|
||||
self.assertTrue(normal.requires_grad == False)
|
||||
for feats in new_clouds.features_list():
|
||||
self.assertTrue(feats.requires_grad == False)
|
||||
|
||||
for attrib in [
|
||||
"points_packed",
|
||||
"normals_packed",
|
||||
"features_packed",
|
||||
"points_padded",
|
||||
"normals_padded",
|
||||
"features_padded",
|
||||
]:
|
||||
self.assertTrue(
|
||||
getattr(new_clouds, attrib)().requires_grad == False
|
||||
)
|
||||
|
||||
self.assertCloudsEqual(clouds, new_clouds)
|
||||
|
||||
def assertCloudsEqual(self, cloud1, cloud2):
|
||||
N = len(cloud1)
|
||||
self.assertEqual(N, len(cloud2))
|
||||
|
||||
Reference in New Issue
Block a user