TexturesVertex._num_verts_per_mesh deep copy (#623)

Summary:
When a list of Meshes is `join_batched()`, the `num_verts_per_mesh` in the list would be unexpectedly modified.

Also some cleanup around `_num_verts_per_mesh`.

Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/623

Test Plan: A modification to an existing test checks this.

Reviewed By: nikhilaravi

Differential Revision: D27682104

Pulled By: bottler

fbshipit-source-id: 9d00913dfb4869bd6c7d3f5cc9156b7b6f1aecc9
This commit is contained in:
JudyYe 2021-04-20 03:10:39 -07:00 committed by Facebook GitHub Bot
parent 8660db9806
commit eb04a488c5
2 changed files with 11 additions and 17 deletions

View File

@ -842,8 +842,7 @@ class TexturesUV(TexturesBase):
else:
# The number of vertices in the mesh and in verts_uvs can differ
# e.g. if a vertex is shared between 3 faces, it can
# have up to 3 different uv coordinates. Therefore we cannot
# convert directly from padded to list using _num_verts_per_mesh
# have up to 3 different uv coordinates.
self._verts_uvs_list = list(self._verts_uvs_padded.unbind(0))
return self._verts_uvs_list
@ -1283,12 +1282,7 @@ class TexturesVertex(TexturesBase):
tex = self.__class__(self.verts_features_padded().clone())
if self._verts_features_list is not None:
tex._verts_features_list = [f.clone() for f in self._verts_features_list]
num_verts = (
self._num_verts_per_mesh.clone()
if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh
)
tex._num_verts_per_mesh = num_verts
tex._num_verts_per_mesh = self._num_verts_per_mesh.copy()
tex.valid = self.valid.clone()
return tex
@ -1296,12 +1290,7 @@ class TexturesVertex(TexturesBase):
tex = self.__class__(self.verts_features_padded().detach())
if self._verts_features_list is not None:
tex._verts_features_list = [f.detach() for f in self._verts_features_list]
num_verts = (
self._num_verts_per_mesh.detach()
if torch.is_tensor(self._num_verts_per_mesh)
else self._num_verts_per_mesh
)
tex._num_verts_per_mesh = num_verts
tex._num_verts_per_mesh = self._num_verts_per_mesh.copy()
tex.valid = self.valid.detach()
return tex
@ -1414,13 +1403,13 @@ class TexturesVertex(TexturesBase):
verts_features_list = []
verts_features_list += self.verts_features_list()
num_faces_per_mesh = self._num_verts_per_mesh
num_verts_per_mesh = self._num_verts_per_mesh.copy()
for tex in textures:
verts_features_list += tex.verts_features_list()
num_faces_per_mesh += tex._num_verts_per_mesh
num_verts_per_mesh += tex._num_verts_per_mesh
new_tex = self.__class__(verts_features=verts_features_list)
new_tex._num_verts_per_mesh = num_faces_per_mesh
new_tex._num_verts_per_mesh = num_verts_per_mesh
return new_tex
def join_scene(self) -> "TexturesVertex":

View File

@ -786,6 +786,11 @@ class TestMeshObjIO(TestCaseMixin, unittest.TestCase):
mesh_rgb = Meshes(verts=[verts], faces=[faces], textures=rgb_tex)
mesh_rgb3 = join_meshes_as_batch([mesh_rgb, mesh_rgb, mesh_rgb])
check_triple(mesh_rgb, mesh_rgb3)
nums_rgb = mesh_rgb.textures._num_verts_per_mesh
nums_rgb3 = mesh_rgb3.textures._num_verts_per_mesh
self.assertEqual(type(nums_rgb), list)
self.assertEqual(type(nums_rgb3), list)
self.assertListEqual(nums_rgb * 3, nums_rgb3)
# meshes with texture atlas, join into a batch.
device = "cuda:0"