diff --git a/pytorch3d/renderer/mesh/textures.py b/pytorch3d/renderer/mesh/textures.py index b4143bc3..dab5aaff 100644 --- a/pytorch3d/renderer/mesh/textures.py +++ b/pytorch3d/renderer/mesh/textures.py @@ -242,6 +242,13 @@ class TexturesBase(object): """ raise NotImplementedError() + def detach(self): + """ + Each texture class should implement a method + to detach all necessary internal tensors. + """ + raise NotImplementedError() + def __getitem__(self, index): """ Each texture class should implement a method @@ -388,6 +395,8 @@ class TexturesAtlas(TexturesBase): def clone(self): tex = self.__class__(atlas=self.atlas_padded().clone()) + if self._atlas_list is not None: + tex._atlas_list = [atlas.clone() for atlas in self._atlas_list] num_faces = ( self._num_faces_per_mesh.clone() if torch.is_tensor(self._num_faces_per_mesh) @@ -397,6 +406,19 @@ class TexturesAtlas(TexturesBase): tex._num_faces_per_mesh = num_faces return tex + def detach(self): + tex = self.__class__(atlas=self.atlas_padded().detach()) + if self._atlas_list is not None: + tex._atlas_list = [atlas.detach() for atlas in self._atlas_list] + num_faces = ( + self._num_faces_per_mesh.detach() + if torch.is_tensor(self._num_faces_per_mesh) + else self._num_faces_per_mesh + ) + tex.valid = self.valid.detach() + tex._num_faces_per_mesh = num_faces + return tex + def __getitem__(self, index): props = ["atlas_list", "_num_faces_per_mesh"] new_props = self._getitem(index, props=props) @@ -656,6 +678,12 @@ class TexturesUV(TexturesBase): self.faces_uvs_padded().clone(), self.verts_uvs_padded().clone(), ) + if self._maps_list is not None: + tex._maps_list = [m.clone() for m in self._maps_list] + if self._verts_uvs_list is not None: + tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list] + if self._faces_uvs_list is not None: + tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list] num_faces = ( self._num_faces_per_mesh.clone() if torch.is_tensor(self._num_faces_per_mesh) @@ -665,6 +693,27 @@ class TexturesUV(TexturesBase): tex.valid = self.valid.clone() return tex + def detach(self): + tex = self.__class__( + self.maps_padded().detach(), + self.faces_uvs_padded().detach(), + self.verts_uvs_padded().detach(), + ) + if self._maps_list is not None: + tex._maps_list = [m.detach() for m in self._maps_list] + if self._verts_uvs_list is not None: + tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list] + if self._faces_uvs_list is not None: + tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list] + num_faces = ( + self._num_faces_per_mesh.detach() + if torch.is_tensor(self._num_faces_per_mesh) + else self._num_faces_per_mesh + ) + tex._num_faces_per_mesh = num_faces + tex.valid = self.valid.detach() + return tex + def __getitem__(self, index): props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"] new_props = self._getitem(index, props) @@ -892,8 +941,8 @@ class TexturesVertex(TexturesBase): has a D dimensional feature vector. Args: - verts_features: (N, V, D) tensor giving a feature vector with - artbitrary dimensions for each vertex. + verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature + vector with artbitrary dimensions for each vertex. """ if isinstance(verts_features, (tuple, list)): correct_shape = all( @@ -948,15 +997,28 @@ class TexturesVertex(TexturesBase): tex = self.__class__(self.verts_features_padded().clone()) if self._verts_features_list is not None: tex._verts_features_list = [f.clone() for f in self._verts_features_list] - num_faces = ( + num_verts = ( self._num_verts_per_mesh.clone() if torch.is_tensor(self._num_verts_per_mesh) else self._num_verts_per_mesh ) - tex._num_verts_per_mesh = num_faces + tex._num_verts_per_mesh = num_verts tex.valid = self.valid.clone() return tex + def detach(self): + tex = self.__class__(self.verts_features_padded().detach()) + if self._verts_features_list is not None: + tex._verts_features_list = [f.detach() for f in self._verts_features_list] + num_verts = ( + self._num_verts_per_mesh.detach() + if torch.is_tensor(self._num_verts_per_mesh) + else self._num_verts_per_mesh + ) + tex._num_verts_per_mesh = num_verts + tex.valid = self.valid.detach() + return tex + def __getitem__(self, index): props = ["verts_features_list", "_num_verts_per_mesh"] new_props = self._getitem(index, props) diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index 2bcd6a56..4e4c1b0a 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -1138,6 +1138,28 @@ class Meshes(object): other.textures = self.textures.clone() return other + def detach(self): + """ + Detach Meshes object. All internal tensors are detached individually. + + Returns: + new Meshes object. + """ + verts_list = self.verts_list() + faces_list = self.faces_list() + new_verts_list = [v.detach() for v in verts_list] + new_faces_list = [f.detach() for f in faces_list] + other = self.__class__(verts=new_verts_list, faces=new_faces_list) + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.detach()) + + # Textures is not a tensor but has a detach method + if self.textures is not None: + other.textures = self.textures.detach() + return other + def to(self, device, copy: bool = False): """ Match functionality of torch.Tensor.to() diff --git a/pytorch3d/structures/pointclouds.py b/pytorch3d/structures/pointclouds.py index 7f5a0e4a..74e3e71a 100644 --- a/pytorch3d/structures/pointclouds.py +++ b/pytorch3d/structures/pointclouds.py @@ -655,6 +655,42 @@ class Pointclouds(object): setattr(other, k, v.clone()) return other + def detach(self): + """ + Detach Pointclouds object. All internal tensors are detached + individually. + + Returns: + new Pointclouds object. + """ + # instantiate new pointcloud with the representation which is not None + # (either list or tensor) to save compute. + new_points, new_normals, new_features = None, None, None + if self._points_list is not None: + new_points = [v.detach() for v in self.points_list()] + normals_list = self.normals_list() + features_list = self.features_list() + if normals_list is not None: + new_normals = [n.detach() for n in normals_list] + if features_list is not None: + new_features = [f.detach() for f in features_list] + elif self._points_padded is not None: + new_points = self.points_padded().detach() + normals_padded = self.normals_padded() + features_padded = self.features_padded() + if normals_padded is not None: + new_normals = self.normals_padded().detach() + if features_padded is not None: + new_features = self.features_padded().detach() + other = self.__class__( + points=new_points, normals=new_normals, features=new_features + ) + for k in self._INTERNAL_TENSORS: + v = getattr(self, k) + if torch.is_tensor(v): + setattr(other, k, v.detach()) + return other + def to(self, device, copy: bool = False): """ Match functionality of torch.Tensor.to() diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 3e3904e5..2117fcb9 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -20,6 +20,7 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): max_f: int = 300, lists_to_tensors: bool = False, device: str = "cpu", + requires_grad: bool = False, ): """ Function to generate a Meshes object of N meshes with @@ -57,7 +58,12 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): # Generate the actual vertices and faces. for i in range(num_meshes): - verts = torch.rand((v[i], 3), dtype=torch.float32, device=device) + verts = torch.rand( + (v[i], 3), + dtype=torch.float32, + device=device, + requires_grad=requires_grad, + ) faces = torch.randint( v[i], size=(f[i], 3), dtype=torch.int64, device=device ) @@ -353,6 +359,26 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertSeparate(new_mesh.faces_padded(), mesh.faces_padded()) self.assertSeparate(new_mesh.edges_packed(), mesh.edges_packed()) + def test_detach(self): + N = 5 + mesh = TestMeshes.init_mesh(N, 10, 100, requires_grad=True) + for force in [0, 1]: + if force: + # force mesh to have computed attributes + mesh.verts_packed() + mesh.edges_packed() + mesh.verts_padded() + + new_mesh = mesh.detach() + + self.assertFalse(new_mesh.verts_packed().requires_grad) + self.assertClose(new_mesh.verts_packed(), mesh.verts_packed()) + self.assertTrue(new_mesh.verts_padded().requires_grad == False) + self.assertClose(new_mesh.verts_padded(), mesh.verts_padded()) + for v, newv in zip(mesh.verts_list(), new_mesh.verts_list()): + self.assertTrue(newv.requires_grad == False) + self.assertClose(newv, v) + def test_laplacian_packed(self): def naive_laplacian_packed(meshes): verts_packed = meshes.verts_packed() diff --git a/tests/test_pointclouds.py b/tests/test_pointclouds.py index 0239fedd..e1b77d2e 100644 --- a/tests/test_pointclouds.py +++ b/tests/test_pointclouds.py @@ -24,6 +24,7 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): with_normals: bool = True, with_features: bool = True, min_points: int = 0, + requires_grad: bool = False, ): """ Function to generate a Pointclouds object of N meshes with @@ -49,16 +50,31 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): p.fill_(p[0]) points_list = [ - torch.rand((i, 3), device=device, dtype=torch.float32) for i in p + torch.rand( + (i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad + ) + for i in p ] normals_list, features_list = None, None if with_normals: normals_list = [ - torch.rand((i, 3), device=device, dtype=torch.float32) for i in p + torch.rand( + (i, 3), + device=device, + dtype=torch.float32, + requires_grad=requires_grad, + ) + for i in p ] if with_features: features_list = [ - torch.rand((i, channels), device=device, dtype=torch.float32) for i in p + torch.rand( + (i, channels), + device=device, + dtype=torch.float32, + requires_grad=requires_grad, + ) + for i in p ] if lists_to_tensors: @@ -382,6 +398,39 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase): self.assertCloudsEqual(clouds, new_clouds) + def test_detach(self): + N = 5 + for lists_to_tensors in (True, False): + clouds = self.init_cloud( + N, 100, 5, lists_to_tensors=lists_to_tensors, requires_grad=True + ) + for force in (False, True): + if force: + clouds.points_packed() + + new_clouds = clouds.detach() + + for cloud in new_clouds.points_list(): + self.assertTrue(cloud.requires_grad == False) + for normal in new_clouds.normals_list(): + self.assertTrue(normal.requires_grad == False) + for feats in new_clouds.features_list(): + self.assertTrue(feats.requires_grad == False) + + for attrib in [ + "points_packed", + "normals_packed", + "features_packed", + "points_padded", + "normals_padded", + "features_padded", + ]: + self.assertTrue( + getattr(new_clouds, attrib)().requires_grad == False + ) + + self.assertCloudsEqual(clouds, new_clouds) + def assertCloudsEqual(self, cloud1, cloud2): N = len(cloud1) self.assertEqual(N, len(cloud2)) diff --git a/tests/test_texturing.py b/tests/test_texturing.py index e71ff760..bcab7ac6 100644 --- a/tests/test_texturing.py +++ b/tests/test_texturing.py @@ -113,11 +113,37 @@ class TestTexturesVertex(TestCaseMixin, unittest.TestCase): def test_clone(self): tex = TexturesVertex(verts_features=torch.rand(size=(10, 100, 128))) + tex.verts_features_list() tex_cloned = tex.clone() self.assertSeparate( tex._verts_features_padded, tex_cloned._verts_features_padded ) + self.assertClose(tex._verts_features_padded, tex_cloned._verts_features_padded) self.assertSeparate(tex.valid, tex_cloned.valid) + self.assertTrue(tex.valid.eq(tex_cloned.valid).all()) + for i in range(tex._N): + self.assertSeparate( + tex._verts_features_list[i], tex_cloned._verts_features_list[i] + ) + self.assertClose( + tex._verts_features_list[i], tex_cloned._verts_features_list[i] + ) + + def test_detach(self): + tex = TexturesVertex( + verts_features=torch.rand(size=(10, 100, 128), requires_grad=True) + ) + tex.verts_features_list() + tex_detached = tex.detach() + self.assertFalse(tex_detached._verts_features_padded.requires_grad) + self.assertClose( + tex_detached._verts_features_padded, tex._verts_features_padded + ) + for i in range(tex._N): + self.assertClose( + tex._verts_features_list[i], tex_detached._verts_features_list[i] + ) + self.assertFalse(tex_detached._verts_features_list[i].requires_grad) def test_extend(self): B = 10 @@ -278,9 +304,25 @@ class TestTexturesAtlas(TestCaseMixin, unittest.TestCase): def test_clone(self): tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3))) + tex.atlas_list() tex_cloned = tex.clone() self.assertSeparate(tex._atlas_padded, tex_cloned._atlas_padded) + self.assertClose(tex._atlas_padded, tex_cloned._atlas_padded) self.assertSeparate(tex.valid, tex_cloned.valid) + self.assertTrue(tex.valid.eq(tex_cloned.valid).all()) + for i in range(tex._N): + self.assertSeparate(tex._atlas_list[i], tex_cloned._atlas_list[i]) + self.assertClose(tex._atlas_list[i], tex_cloned._atlas_list[i]) + + def test_detach(self): + tex = TexturesAtlas(atlas=torch.rand(size=(1, 10, 2, 2, 3), requires_grad=True)) + tex.atlas_list() + tex_detached = tex.detach() + self.assertFalse(tex_detached._atlas_padded.requires_grad) + self.assertClose(tex_detached._atlas_padded, tex._atlas_padded) + for i in range(tex._N): + self.assertFalse(tex_detached._atlas_list[i].requires_grad) + self.assertClose(tex._atlas_list[i], tex_detached._atlas_list[i]) def test_extend(self): B = 10 @@ -478,11 +520,49 @@ class TestTexturesUV(TestCaseMixin, unittest.TestCase): faces_uvs=torch.rand(size=(5, 10, 3)), verts_uvs=torch.rand(size=(5, 15, 2)), ) + tex.faces_uvs_list() + tex.verts_uvs_list() tex_cloned = tex.clone() self.assertSeparate(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) + self.assertClose(tex._faces_uvs_padded, tex_cloned._faces_uvs_padded) self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) + self.assertClose(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded) self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded) + self.assertClose(tex._maps_padded, tex_cloned._maps_padded) self.assertSeparate(tex.valid, tex_cloned.valid) + self.assertTrue(tex.valid.eq(tex_cloned.valid).all()) + for i in range(tex._N): + self.assertSeparate(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i]) + self.assertClose(tex._faces_uvs_list[i], tex_cloned._faces_uvs_list[i]) + self.assertSeparate(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i]) + self.assertClose(tex._verts_uvs_list[i], tex_cloned._verts_uvs_list[i]) + # tex._maps_list is not use anywhere so it's not stored. We call it explicitly + self.assertSeparate(tex.maps_list()[i], tex_cloned.maps_list()[i]) + self.assertClose(tex.maps_list()[i], tex_cloned.maps_list()[i]) + + def test_detach(self): + tex = TexturesUV( + maps=torch.ones((5, 16, 16, 3), requires_grad=True), + faces_uvs=torch.rand(size=(5, 10, 3)), + verts_uvs=torch.rand(size=(5, 15, 2)), + ) + tex.faces_uvs_list() + tex.verts_uvs_list() + tex_detached = tex.detach() + self.assertFalse(tex_detached._maps_padded.requires_grad) + self.assertClose(tex._maps_padded, tex_detached._maps_padded) + self.assertFalse(tex_detached._verts_uvs_padded.requires_grad) + self.assertClose(tex._verts_uvs_padded, tex_detached._verts_uvs_padded) + self.assertFalse(tex_detached._faces_uvs_padded.requires_grad) + self.assertClose(tex._faces_uvs_padded, tex_detached._faces_uvs_padded) + for i in range(tex._N): + self.assertFalse(tex_detached._verts_uvs_list[i].requires_grad) + self.assertClose(tex._verts_uvs_list[i], tex_detached._verts_uvs_list[i]) + self.assertFalse(tex_detached._faces_uvs_list[i].requires_grad) + self.assertClose(tex._faces_uvs_list[i], tex_detached._faces_uvs_list[i]) + # tex._maps_list is not use anywhere so it's not stored. We call it explicitly + self.assertFalse(tex_detached.maps_list()[i].requires_grad) + self.assertClose(tex.maps_list()[i], tex_detached.maps_list()[i]) def test_extend(self): B = 5