detach for meshes, pointclouds, textures

Summary: Add `detach` for Meshes, Pointclouds, Textures

Reviewed By: nikhilaravi

Differential Revision: D23070418

fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
Georgia Gkioxari
2020-08-17 14:53:56 -07:00
committed by Facebook GitHub Bot
parent 5852b74d12
commit 7f2f95f225
6 changed files with 283 additions and 8 deletions

View File

@@ -1138,6 +1138,28 @@ class Meshes(object):
other.textures = self.textures.clone()
return other
def detach(self):
"""
Detach Meshes object. All internal tensors are detached individually.
Returns:
new Meshes object.
"""
verts_list = self.verts_list()
faces_list = self.faces_list()
new_verts_list = [v.detach() for v in verts_list]
new_faces_list = [f.detach() for f in faces_list]
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.detach())
# Textures is not a tensor but has a detach method
if self.textures is not None:
other.textures = self.textures.detach()
return other
def to(self, device, copy: bool = False):
"""
Match functionality of torch.Tensor.to()

View File

@@ -655,6 +655,42 @@ class Pointclouds(object):
setattr(other, k, v.clone())
return other
def detach(self):
"""
Detach Pointclouds object. All internal tensors are detached
individually.
Returns:
new Pointclouds object.
"""
# instantiate new pointcloud with the representation which is not None
# (either list or tensor) to save compute.
new_points, new_normals, new_features = None, None, None
if self._points_list is not None:
new_points = [v.detach() for v in self.points_list()]
normals_list = self.normals_list()
features_list = self.features_list()
if normals_list is not None:
new_normals = [n.detach() for n in normals_list]
if features_list is not None:
new_features = [f.detach() for f in features_list]
elif self._points_padded is not None:
new_points = self.points_padded().detach()
normals_padded = self.normals_padded()
features_padded = self.features_padded()
if normals_padded is not None:
new_normals = self.normals_padded().detach()
if features_padded is not None:
new_features = self.features_padded().detach()
other = self.__class__(
points=new_points, normals=new_normals, features=new_features
)
for k in self._INTERNAL_TENSORS:
v = getattr(self, k)
if torch.is_tensor(v):
setattr(other, k, v.detach())
return other
def to(self, device, copy: bool = False):
"""
Match functionality of torch.Tensor.to()