mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 22:30:35 +08:00
detach for meshes, pointclouds, textures
Summary: Add `detach` for Meshes, Pointclouds, Textures Reviewed By: nikhilaravi Differential Revision: D23070418 fbshipit-source-id: 68671124ce114c4495d7ef3c944c9aac3d0db2d8
This commit is contained in:
committed by
Facebook GitHub Bot
parent
5852b74d12
commit
7f2f95f225
@@ -242,6 +242,13 @@ class TexturesBase(object):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def detach(self):
|
||||
"""
|
||||
Each texture class should implement a method
|
||||
to detach all necessary internal tensors.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __getitem__(self, index):
|
||||
"""
|
||||
Each texture class should implement a method
|
||||
@@ -388,6 +395,8 @@ class TexturesAtlas(TexturesBase):
|
||||
|
||||
def clone(self):
|
||||
tex = self.__class__(atlas=self.atlas_padded().clone())
|
||||
if self._atlas_list is not None:
|
||||
tex._atlas_list = [atlas.clone() for atlas in self._atlas_list]
|
||||
num_faces = (
|
||||
self._num_faces_per_mesh.clone()
|
||||
if torch.is_tensor(self._num_faces_per_mesh)
|
||||
@@ -397,6 +406,19 @@ class TexturesAtlas(TexturesBase):
|
||||
tex._num_faces_per_mesh = num_faces
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
tex = self.__class__(atlas=self.atlas_padded().detach())
|
||||
if self._atlas_list is not None:
|
||||
tex._atlas_list = [atlas.detach() for atlas in self._atlas_list]
|
||||
num_faces = (
|
||||
self._num_faces_per_mesh.detach()
|
||||
if torch.is_tensor(self._num_faces_per_mesh)
|
||||
else self._num_faces_per_mesh
|
||||
)
|
||||
tex.valid = self.valid.detach()
|
||||
tex._num_faces_per_mesh = num_faces
|
||||
return tex
|
||||
|
||||
def __getitem__(self, index):
|
||||
props = ["atlas_list", "_num_faces_per_mesh"]
|
||||
new_props = self._getitem(index, props=props)
|
||||
@@ -656,6 +678,12 @@ class TexturesUV(TexturesBase):
|
||||
self.faces_uvs_padded().clone(),
|
||||
self.verts_uvs_padded().clone(),
|
||||
)
|
||||
if self._maps_list is not None:
|
||||
tex._maps_list = [m.clone() for m in self._maps_list]
|
||||
if self._verts_uvs_list is not None:
|
||||
tex._verts_uvs_list = [v.clone() for v in self._verts_uvs_list]
|
||||
if self._faces_uvs_list is not None:
|
||||
tex._faces_uvs_list = [f.clone() for f in self._faces_uvs_list]
|
||||
num_faces = (
|
||||
self._num_faces_per_mesh.clone()
|
||||
if torch.is_tensor(self._num_faces_per_mesh)
|
||||
@@ -665,6 +693,27 @@ class TexturesUV(TexturesBase):
|
||||
tex.valid = self.valid.clone()
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
tex = self.__class__(
|
||||
self.maps_padded().detach(),
|
||||
self.faces_uvs_padded().detach(),
|
||||
self.verts_uvs_padded().detach(),
|
||||
)
|
||||
if self._maps_list is not None:
|
||||
tex._maps_list = [m.detach() for m in self._maps_list]
|
||||
if self._verts_uvs_list is not None:
|
||||
tex._verts_uvs_list = [v.detach() for v in self._verts_uvs_list]
|
||||
if self._faces_uvs_list is not None:
|
||||
tex._faces_uvs_list = [f.detach() for f in self._faces_uvs_list]
|
||||
num_faces = (
|
||||
self._num_faces_per_mesh.detach()
|
||||
if torch.is_tensor(self._num_faces_per_mesh)
|
||||
else self._num_faces_per_mesh
|
||||
)
|
||||
tex._num_faces_per_mesh = num_faces
|
||||
tex.valid = self.valid.detach()
|
||||
return tex
|
||||
|
||||
def __getitem__(self, index):
|
||||
props = ["verts_uvs_list", "faces_uvs_list", "maps_list", "_num_faces_per_mesh"]
|
||||
new_props = self._getitem(index, props)
|
||||
@@ -892,8 +941,8 @@ class TexturesVertex(TexturesBase):
|
||||
has a D dimensional feature vector.
|
||||
|
||||
Args:
|
||||
verts_features: (N, V, D) tensor giving a feature vector with
|
||||
artbitrary dimensions for each vertex.
|
||||
verts_features: list of (Vi, D) or (N, V, D) tensor giving a feature
|
||||
vector with artbitrary dimensions for each vertex.
|
||||
"""
|
||||
if isinstance(verts_features, (tuple, list)):
|
||||
correct_shape = all(
|
||||
@@ -948,15 +997,28 @@ class TexturesVertex(TexturesBase):
|
||||
tex = self.__class__(self.verts_features_padded().clone())
|
||||
if self._verts_features_list is not None:
|
||||
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
|
||||
num_faces = (
|
||||
num_verts = (
|
||||
self._num_verts_per_mesh.clone()
|
||||
if torch.is_tensor(self._num_verts_per_mesh)
|
||||
else self._num_verts_per_mesh
|
||||
)
|
||||
tex._num_verts_per_mesh = num_faces
|
||||
tex._num_verts_per_mesh = num_verts
|
||||
tex.valid = self.valid.clone()
|
||||
return tex
|
||||
|
||||
def detach(self):
|
||||
tex = self.__class__(self.verts_features_padded().detach())
|
||||
if self._verts_features_list is not None:
|
||||
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
|
||||
num_verts = (
|
||||
self._num_verts_per_mesh.detach()
|
||||
if torch.is_tensor(self._num_verts_per_mesh)
|
||||
else self._num_verts_per_mesh
|
||||
)
|
||||
tex._num_verts_per_mesh = num_verts
|
||||
tex.valid = self.valid.detach()
|
||||
return tex
|
||||
|
||||
def __getitem__(self, index):
|
||||
props = ["verts_features_list", "_num_verts_per_mesh"]
|
||||
new_props = self._getitem(index, props)
|
||||
|
||||
@@ -1138,6 +1138,28 @@ class Meshes(object):
|
||||
other.textures = self.textures.clone()
|
||||
return other
|
||||
|
||||
def detach(self):
|
||||
"""
|
||||
Detach Meshes object. All internal tensors are detached individually.
|
||||
|
||||
Returns:
|
||||
new Meshes object.
|
||||
"""
|
||||
verts_list = self.verts_list()
|
||||
faces_list = self.faces_list()
|
||||
new_verts_list = [v.detach() for v in verts_list]
|
||||
new_faces_list = [f.detach() for f in faces_list]
|
||||
other = self.__class__(verts=new_verts_list, faces=new_faces_list)
|
||||
for k in self._INTERNAL_TENSORS:
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.detach())
|
||||
|
||||
# Textures is not a tensor but has a detach method
|
||||
if self.textures is not None:
|
||||
other.textures = self.textures.detach()
|
||||
return other
|
||||
|
||||
def to(self, device, copy: bool = False):
|
||||
"""
|
||||
Match functionality of torch.Tensor.to()
|
||||
|
||||
@@ -655,6 +655,42 @@ class Pointclouds(object):
|
||||
setattr(other, k, v.clone())
|
||||
return other
|
||||
|
||||
def detach(self):
|
||||
"""
|
||||
Detach Pointclouds object. All internal tensors are detached
|
||||
individually.
|
||||
|
||||
Returns:
|
||||
new Pointclouds object.
|
||||
"""
|
||||
# instantiate new pointcloud with the representation which is not None
|
||||
# (either list or tensor) to save compute.
|
||||
new_points, new_normals, new_features = None, None, None
|
||||
if self._points_list is not None:
|
||||
new_points = [v.detach() for v in self.points_list()]
|
||||
normals_list = self.normals_list()
|
||||
features_list = self.features_list()
|
||||
if normals_list is not None:
|
||||
new_normals = [n.detach() for n in normals_list]
|
||||
if features_list is not None:
|
||||
new_features = [f.detach() for f in features_list]
|
||||
elif self._points_padded is not None:
|
||||
new_points = self.points_padded().detach()
|
||||
normals_padded = self.normals_padded()
|
||||
features_padded = self.features_padded()
|
||||
if normals_padded is not None:
|
||||
new_normals = self.normals_padded().detach()
|
||||
if features_padded is not None:
|
||||
new_features = self.features_padded().detach()
|
||||
other = self.__class__(
|
||||
points=new_points, normals=new_normals, features=new_features
|
||||
)
|
||||
for k in self._INTERNAL_TENSORS:
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(other, k, v.detach())
|
||||
return other
|
||||
|
||||
def to(self, device, copy: bool = False):
|
||||
"""
|
||||
Match functionality of torch.Tensor.to()
|
||||
|
||||
Reference in New Issue
Block a user