mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
getitem for textures
Summary: Make Meshes.__getitem__ carry texture information to the new mesh. Reviewed By: gkioxari Differential Revision: D20283976 fbshipit-source-id: d9ee0580c11ac5b4384df9d8158a07e6eb8d00fe
This commit is contained in:
parent
5a1d7143d8
commit
fb97ab104e
@ -415,10 +415,12 @@ class Meshes(object):
|
||||
else:
|
||||
raise IndexError(index)
|
||||
|
||||
textures = None if self.textures is None else self.textures[index]
|
||||
|
||||
if torch.is_tensor(verts) and torch.is_tensor(faces):
|
||||
return Meshes(verts=[verts], faces=[faces])
|
||||
return Meshes(verts=[verts], faces=[faces], textures=textures)
|
||||
elif isinstance(verts, list) and isinstance(faces, list):
|
||||
return Meshes(verts=verts, faces=faces)
|
||||
return Meshes(verts=verts, faces=faces, textures=textures)
|
||||
else:
|
||||
raise ValueError("(verts, faces) not defined correctly")
|
||||
|
||||
|
@ -115,8 +115,14 @@ class Textures(object):
|
||||
self._verts_rgb_padded = verts_rgb
|
||||
self._maps_padded = maps
|
||||
self._num_faces_per_mesh = None
|
||||
self._set_num_faces_per_mesh()
|
||||
|
||||
def _set_num_faces_per_mesh(self) -> None:
|
||||
"""
|
||||
Determines and sets the number of textured faces for each mesh.
|
||||
"""
|
||||
if self._faces_uvs_padded is not None:
|
||||
faces_uvs = self._faces_uvs_padded
|
||||
self._num_faces_per_mesh = faces_uvs.gt(-1).all(-1).sum(-1).tolist()
|
||||
|
||||
def clone(self):
|
||||
@ -134,6 +140,18 @@ class Textures(object):
|
||||
setattr(self, k, v.to(device))
|
||||
return self
|
||||
|
||||
def __getitem__(self, index):
|
||||
other = Textures()
|
||||
for key in dir(self):
|
||||
value = getattr(self, key)
|
||||
if torch.is_tensor(value):
|
||||
if isinstance(index, int):
|
||||
setattr(other, key, value[index][None])
|
||||
else:
|
||||
setattr(other, key, value[index])
|
||||
other._set_num_faces_per_mesh()
|
||||
return other
|
||||
|
||||
def faces_uvs_padded(self) -> torch.Tensor:
|
||||
return self._faces_uvs_padded
|
||||
|
||||
|
@ -167,6 +167,55 @@ class TestTexturing(TestCaseMixin, unittest.TestCase):
|
||||
self.assertSeparate(tex._verts_uvs_padded, tex_cloned._verts_uvs_padded)
|
||||
self.assertSeparate(tex._maps_padded, tex_cloned._maps_padded)
|
||||
|
||||
def test_getitem(self):
|
||||
N = 5
|
||||
V = 20
|
||||
source = {
|
||||
"maps": torch.rand(size=(N, 16, 16, 3)),
|
||||
"faces_uvs": torch.randint(size=(N, 10, 3), low=0, high=V),
|
||||
"verts_uvs": torch.rand((N, V, 2)),
|
||||
}
|
||||
tex = Textures(
|
||||
maps=source["maps"],
|
||||
faces_uvs=source["faces_uvs"],
|
||||
verts_uvs=source["verts_uvs"],
|
||||
)
|
||||
|
||||
verts = torch.rand(size=(N, V, 3))
|
||||
faces = torch.randint(size=(N, 10, 3), high=V)
|
||||
meshes = Meshes(verts=verts, faces=faces, textures=tex)
|
||||
|
||||
def tryindex(index):
|
||||
tex2 = tex[index]
|
||||
meshes2 = meshes[index]
|
||||
tex_from_meshes = meshes2.textures
|
||||
for item in source:
|
||||
basic = source[item][index]
|
||||
from_texture = getattr(tex2, item + "_padded")()
|
||||
from_meshes = getattr(tex_from_meshes, item + "_padded")()
|
||||
if isinstance(index, int):
|
||||
basic = basic[None]
|
||||
self.assertClose(basic, from_texture)
|
||||
self.assertClose(basic, from_meshes)
|
||||
self.assertEqual(
|
||||
from_texture.ndim, getattr(tex, item + "_padded")().ndim
|
||||
)
|
||||
if item == "faces_uvs":
|
||||
faces_uvs_list = tex_from_meshes.faces_uvs_list()
|
||||
self.assertEqual(basic.shape[0], len(faces_uvs_list))
|
||||
for i, faces_uvs in enumerate(faces_uvs_list):
|
||||
self.assertClose(faces_uvs, basic[i])
|
||||
|
||||
tryindex(2)
|
||||
tryindex(slice(0, 2, 1))
|
||||
index = torch.tensor([1, 0, 1, 0, 0], dtype=torch.bool)
|
||||
tryindex(index)
|
||||
index = torch.tensor([0, 0, 0, 0, 0], dtype=torch.bool)
|
||||
tryindex(index)
|
||||
index = torch.tensor([1, 2], dtype=torch.int64)
|
||||
tryindex(index)
|
||||
tryindex([2, 4])
|
||||
|
||||
def test_to(self):
|
||||
V = 20
|
||||
tex = Textures(
|
||||
|
Loading…
x
Reference in New Issue
Block a user