diff --git a/pytorch3d/structures/meshes.py b/pytorch3d/structures/meshes.py index dffcb3c5..08f804a2 100644 --- a/pytorch3d/structures/meshes.py +++ b/pytorch3d/structures/meshes.py @@ -207,7 +207,14 @@ class Meshes(object): "equisized", ] - def __init__(self, verts=None, faces=None, textures=None): + def __init__( + self, + verts=None, + faces=None, + textures=None, + *, + verts_normals=None, + ): """ Args: verts: @@ -229,6 +236,17 @@ class Meshes(object): the same number of faces. textures: Optional instance of the Textures class with mesh texture properties. + verts_normals: + Optional. Can be either + + - List where each element is a tensor of shape (num_verts, 3) + containing the normals of each vertex. + - Padded float tensor with shape (num_meshes, max_num_verts, 3). + They should be padded with fill value of 0 so they all have + the same number of vertices. + Note that modifying the mesh later, e.g. with offset_verts_, + can cause these normals to be forgotten and normals to be recalculated + based on the new vertex positions. Refer to comments above for descriptions of List and Padded representations. """ @@ -354,8 +372,8 @@ class Meshes(object): self.equisized = True elif torch.is_tensor(verts) and torch.is_tensor(faces): - if verts.size(2) != 3 and faces.size(2) != 3: - raise ValueError("Verts and Faces tensors have incorrect dimensions.") + if verts.size(2) != 3 or faces.size(2) != 3: + raise ValueError("Verts or Faces tensors have incorrect dimensions.") self._verts_padded = verts self._faces_padded = faces.to(torch.int64) self._N = self._verts_padded.shape[0] @@ -412,6 +430,36 @@ class Meshes(object): self.textures._N = self._N self.textures.valid = self.valid + if verts_normals is not None: + self._set_verts_normals(verts_normals) + + def _set_verts_normals(self, verts_normals) -> None: + if isinstance(verts_normals, list): + if len(verts_normals) != self._N: + raise ValueError("Invalid verts_normals input") + + for item, n_verts in zip(verts_normals, self._num_verts_per_mesh): + if ( + not isinstance(item, torch.Tensor) + or item.ndim != 2 + or item.shape[1] != 3 + or item.shape[0] != n_verts + ): + raise ValueError("Invalid verts_normals input") + self._verts_normals_packed = torch.cat(verts_normals, 0) + elif torch.is_tensor(verts_normals): + if ( + verts_normals.ndim != 3 + or verts_normals.size(2) != 3 + or verts_normals.size(0) != self._N + ): + raise ValueError("Vertex normals tensor has incorrect dimensions.") + self._verts_normals_packed = struct_utils.padded_to_packed( + verts_normals, split_size=self._num_verts_per_mesh.tolist() + ) + else: + raise ValueError("verts_normals must be a list or tensor") + def __len__(self): return self._N @@ -1253,6 +1301,7 @@ class Meshes(object): def offset_verts_(self, vert_offsets_packed): """ Add an offset to the vertices of this Meshes. In place operation. + If normals are present they may be recalculated. Args: vert_offsets_packed: A Tensor of shape (3,) or the same shape as @@ -1286,7 +1335,7 @@ class Meshes(object): self._verts_padded[i, : verts.shape[0], :] = verts # update face areas and normals and vertex normals - # only if the original attributes are computed + # only if the original attributes are present if update_normals and any( v is not None for v in [self._faces_areas_packed, self._faces_normals_packed] diff --git a/tests/test_mesh_normal_consistency.py b/tests/test_mesh_normal_consistency.py index d597facd..3f2f1a93 100644 --- a/tests/test_mesh_normal_consistency.py +++ b/tests/test_mesh_normal_consistency.py @@ -223,7 +223,7 @@ class TestMeshNormalConsistency(unittest.TestCase): Test Mesh Normal Consistency for a mesh known to have no intersecting faces. """ - verts = torch.rand(1, 6, 2) + verts = torch.rand(1, 6, 3) faces = torch.arange(6).reshape(1, 2, 3) meshes = Meshes(verts=verts, faces=faces) out = mesh_normal_consistency(meshes) diff --git a/tests/test_meshes.py b/tests/test_meshes.py index 7dc7a15c..deee594f 100644 --- a/tests/test_meshes.py +++ b/tests/test_meshes.py @@ -1138,6 +1138,20 @@ class TestMeshes(TestCaseMixin, unittest.TestCase): self.assertEqual(meshes.faces_normals_padded().shape[0], 0) self.assertEqual(meshes.faces_normals_list(), []) + def test_assigned_normals(self): + verts = torch.rand(2, 6, 3) + faces = torch.randint(6, size=(2, 4, 3)) + + for verts_normals in [list(verts.unbind(0)), verts]: + yes_normals = Meshes( + verts=verts.clone(), faces=faces, verts_normals=verts_normals + ) + self.assertClose(yes_normals.verts_normals_padded(), verts) + yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3])) + self.assertClose(yes_normals.verts_normals_padded(), verts) + yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]).expand(12, 3)) + self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts)) + def test_compute_faces_areas_cpu_cuda(self): num_meshes = 10 max_v = 100