getitem for textures

Summary: Make Meshes.__getitem__ carry texture information to the new mesh.

Reviewed By: gkioxari

Differential Revision: D20283976

fbshipit-source-id: d9ee0580c11ac5b4384df9d8158a07e6eb8d00fe
This commit is contained in:
Jeremy Reizenstein
2020-03-11 07:42:35 -07:00
committed by Facebook GitHub Bot
parent 5a1d7143d8
commit fb97ab104e
3 changed files with 71 additions and 2 deletions

View File

@@ -415,10 +415,12 @@ class Meshes(object):
else:
raise IndexError(index)
textures = None if self.textures is None else self.textures[index]
if torch.is_tensor(verts) and torch.is_tensor(faces):
return Meshes(verts=[verts], faces=[faces])
return Meshes(verts=[verts], faces=[faces], textures=textures)
elif isinstance(verts, list) and isinstance(faces, list):
return Meshes(verts=verts, faces=faces)
return Meshes(verts=verts, faces=faces, textures=textures)
else:
raise ValueError("(verts, faces) not defined correctly")

View File

@@ -115,8 +115,14 @@ class Textures(object):
self._verts_rgb_padded = verts_rgb
self._maps_padded = maps
self._num_faces_per_mesh = None
self._set_num_faces_per_mesh()
def _set_num_faces_per_mesh(self) -> None:
"""
Determines and sets the number of textured faces for each mesh.
"""
if self._faces_uvs_padded is not None:
faces_uvs = self._faces_uvs_padded
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
def clone(self):
@@ -134,6 +140,18 @@ class Textures(object):
setattr(self, k, v.to(device))
return self
def __getitem__(self, index):
other = Textures()
for key in dir(self):
value = getattr(self, key)
if torch.is_tensor(value):
if isinstance(index, int):
setattr(other, key, value[index][None])
else:
setattr(other, key, value[index])
other._set_num_faces_per_mesh()
return other
def faces_uvs_padded(self) -> torch.Tensor:
return self._faces_uvs_padded