detach for meshes, pointclouds, textures

Summary: Add `detach` for Meshes, Pointclouds, Textures

Reviewed By: nikhilaravi

Differential Revision: D23070418

fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
Georgia Gkioxari
2020-08-17 14:53:56 -07:00
committed by Facebook GitHub Bot
parent 5852b74d12
commit 7f2f95f225
6 changed files with 283 additions and 8 deletions

View File

@@ -113,11 +113,37 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase):
def test_clone(self):
tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128)))
tex.verts_features_list()
tex_cloned = tex.clone()
self.assertSeparate(
tex._verts_features_padded, tex_cloned._verts_features_padded
)
self.assertClose(tex._verts_features_padded, tex_cloned._verts_features_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
)
self.assertClose(
tex._verts_features_list[i], tex_cloned._verts_features_list[i]
)
def test_detach(self):
tex = TexturesVertex(
verts_features=torch.rand(size=(10, 100, 128), requires_grad=True)
)
tex.verts_features_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._verts_features_padded.requires_grad)
self.assertClose(
tex_detached._verts_features_padded, tex._verts_features_padded
)
for i in range(tex._N):
self.assertClose(
tex._verts_features_list[i], tex_detached._verts_features_list[i]
)
self.assertFalse(tex_detached._verts_features_list[i].requires_grad)
def test_extend(self):
B = 10
@@ -278,9 +304,25 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase):
def test_clone(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3)))
tex.atlas_list()
tex_cloned = tex.clone()
self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertClose(tex._atlas_padded, tex_cloned._atlas_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(tex._atlas_list[i], tex_cloned._atlas_list[i])
self.assertClose(tex._atlas_list[i], tex_cloned._atlas_list[i])
def test_detach(self):
tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3), requires_grad=True))
tex.atlas_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._atlas_padded.requires_grad)
self.assertClose(tex_detached._atlas_padded, tex._atlas_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._atlas_list[i].requires_grad)
self.assertClose(tex._atlas_list[i], tex_detached._atlas_list[i])
def test_extend(self):
B = 10
@@ -478,11 +520,49 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase):
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_cloned = tex.clone()
self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded)
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
self.assertClose(tex._maps_padded, tex_cloned._maps_padded)
self.assertSeparate(tex.valid, tex_cloned.valid)
self.assertTrue(tex.valid.eq(tex_cloned.valid).all())
for i in range(tex._N):
self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i])
self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i])
self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i])
def test_detach(self):
tex = TexturesUV(
maps=torch.ones((5, 16, 16, 3), requires_grad=True),
faces_uvs=torch.rand(size=(5, 10, 3)),
verts_uvs=torch.rand(size=(5, 15, 2)),
)
tex.faces_uvs_list()
tex.verts_uvs_list()
tex_detached = tex.detach()
self.assertFalse(tex_detached._maps_padded.requires_grad)
self.assertClose(tex._maps_padded, tex_detached._maps_padded)
self.assertFalse(tex_detached._verts_uvs_padded.requires_grad)
self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded)
self.assertFalse(tex_detached._faces_uvs_padded.requires_grad)
self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded)
for i in range(tex._N):
self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad)
self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i])
self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad)
self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i])
# tex._maps_list is not use anywhere so it's not stored. We call it explicitly
self.assertFalse(tex_detached.maps_list()[i].requires_grad)
self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i])
def test_extend(self):
B = 5