Allow setting verts_normals on Meshes

Summary: Add ability to set the vertex normals when creating a Meshes, so that the pluggable loaders can return them from a file.

Reviewed By: nikhilaravi

Differential Revision: D27765258

fbshipit-source-id: b5ddaa00de3707f636f94d9f74d1da12ecce0608
This commit is contained in:
Jeremy Reizenstein 2021-05-04 05:35:24 -07:00 committed by Facebook GitHub Bot
parent 502f15aca7
commit 2bbca5f2a7
3 changed files with 68 additions and 5 deletions

View File

@ -207,7 +207,14 @@ class Meshes(object):
"equisized", "equisized",
] ]
def __init__(self, verts=None, faces=None, textures=None): def __init__(
self,
verts=None,
faces=None,
textures=None,
*,
verts_normals=None,
):
""" """
Args: Args:
verts: verts:
@ -229,6 +236,17 @@ class Meshes(object):
the same number of faces. the same number of faces.
textures: Optional instance of the Textures class with mesh textures: Optional instance of the Textures class with mesh
texture properties. texture properties.
verts_normals:
Optional. Can be either
- List where each element is a tensor of shape (num_verts, 3)
containing the normals of each vertex.
- Padded float tensor with shape (num_meshes, max_num_verts, 3).
They should be padded with fill value of 0 so they all have
the same number of vertices.
Note that modifying the mesh later, e.g. with offset_verts_,
can cause these normals to be forgotten and normals to be recalculated
based on the new vertex positions.
Refer to comments above for descriptions of List and Padded representations. Refer to comments above for descriptions of List and Padded representations.
""" """
@ -354,8 +372,8 @@ class Meshes(object):
self.equisized = True self.equisized = True
elif torch.is_tensor(verts) and torch.is_tensor(faces): elif torch.is_tensor(verts) and torch.is_tensor(faces):
if verts.size(2) != 3 and faces.size(2) != 3: if verts.size(2) != 3 or faces.size(2) != 3:
raise ValueError("Verts and Faces tensors have incorrect dimensions.") raise ValueError("Verts or Faces tensors have incorrect dimensions.")
self._verts_padded = verts self._verts_padded = verts
self._faces_padded = faces.to(torch.int64) self._faces_padded = faces.to(torch.int64)
self._N = self._verts_padded.shape[0] self._N = self._verts_padded.shape[0]
@ -412,6 +430,36 @@ class Meshes(object):
self.textures._N = self._N self.textures._N = self._N
self.textures.valid = self.valid self.textures.valid = self.valid
if verts_normals is not None:
self._set_verts_normals(verts_normals)
def _set_verts_normals(self, verts_normals) -> None:
if isinstance(verts_normals, list):
if len(verts_normals) != self._N:
raise ValueError("Invalid verts_normals input")
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
if (
not isinstance(item, torch.Tensor)
or item.ndim != 2
or item.shape[1] != 3
or item.shape[0] != n_verts
):
raise ValueError("Invalid verts_normals input")
self._verts_normals_packed = torch.cat(verts_normals, 0)
elif torch.is_tensor(verts_normals):
if (
verts_normals.ndim != 3
or verts_normals.size(2) != 3
or verts_normals.size(0) != self._N
):
raise ValueError("Vertex normals tensor has incorrect dimensions.")
self._verts_normals_packed = struct_utils.padded_to_packed(
verts_normals, split_size=self._num_verts_per_mesh.tolist()
)
else:
raise ValueError("verts_normals must be a list or tensor")
def __len__(self): def __len__(self):
return self._N return self._N
@ -1253,6 +1301,7 @@ class Meshes(object):
def offset_verts_(self, vert_offsets_packed): def offset_verts_(self, vert_offsets_packed):
""" """
Add an offset to the vertices of this Meshes. In place operation. Add an offset to the vertices of this Meshes. In place operation.
If normals are present they may be recalculated.
Args: Args:
vert_offsets_packed: A Tensor of shape (3,) or the same shape as vert_offsets_packed: A Tensor of shape (3,) or the same shape as
@ -1286,7 +1335,7 @@ class Meshes(object):
self._verts_padded[i, : verts.shape[0], :] = verts self._verts_padded[i, : verts.shape[0], :] = verts
# update face areas and normals and vertex normals # update face areas and normals and vertex normals
# only if the original attributes are computed # only if the original attributes are present
if update_normals and any( if update_normals and any(
v is not None v is not None
for v in [self._faces_areas_packed, self._faces_normals_packed] for v in [self._faces_areas_packed, self._faces_normals_packed]

View File

@ -223,7 +223,7 @@ class TestMeshNormalConsistency(unittest.TestCase):
Test Mesh Normal Consistency for a mesh known to have no Test Mesh Normal Consistency for a mesh known to have no
intersecting faces. intersecting faces.
""" """
verts = torch.rand(1, 6, 2) verts = torch.rand(1, 6, 3)
faces = torch.arange(6).reshape(1, 2, 3) faces = torch.arange(6).reshape(1, 2, 3)
meshes = Meshes(verts=verts, faces=faces) meshes = Meshes(verts=verts, faces=faces)
out = mesh_normal_consistency(meshes) out = mesh_normal_consistency(meshes)

View File

@ -1138,6 +1138,20 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(meshes.faces_normals_padded().shape[0], 0) self.assertEqual(meshes.faces_normals_padded().shape[0], 0)
self.assertEqual(meshes.faces_normals_list(), []) self.assertEqual(meshes.faces_normals_list(), [])
def test_assigned_normals(self):
verts = torch.rand(2, 6, 3)
faces = torch.randint(6, size=(2, 4, 3))
for verts_normals in [list(verts.unbind(0)), verts]:
yes_normals = Meshes(
verts=verts.clone(), faces=faces, verts_normals=verts_normals
)
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]))
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]).expand(12, 3))
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))
def test_compute_faces_areas_cpu_cuda(self): def test_compute_faces_areas_cpu_cuda(self):
num_meshes = 10 num_meshes = 10
max_v = 100 max_v = 100