Allow setting verts_normals on Meshes

Summary: Add ability to set the vertex normals when creating a Meshes, so that the pluggable loaders can return them from a file.

Reviewed By: nikhilaravi

Differential Revision: D27765258

fbshipit-source-id: b5ddaa00de3707f636f94d9f74d1da12ecce0608
This commit is contained in:
Jeremy Reizenstein 2021-05-04 05:35:24 -07:00 committed by Facebook GitHub Bot
parent 502f15aca7
commit 2bbca5f2a7
3 changed files with 68 additions and 5 deletions

View File

@ -207,7 +207,14 @@ class Meshes(object):
"equisized",
]
def __init__(self, verts=None, faces=None, textures=None):
def __init__(
self,
verts=None,
faces=None,
textures=None,
*,
verts_normals=None,
):
"""
Args:
verts:
@ -229,6 +236,17 @@ class Meshes(object):
the same number of faces.
textures: Optional instance of the Textures class with mesh
texture properties.
verts_normals:
Optional. Can be either
- List where each element is a tensor of shape (num_verts, 3)
containing the normals of each vertex.
- Padded float tensor with shape (num_meshes, max_num_verts, 3).
They should be padded with fill value of 0 so they all have
the same number of vertices.
Note that modifying the mesh later, e.g. with offset_verts_,
can cause these normals to be forgotten and normals to be recalculated
based on the new vertex positions.
Refer to comments above for descriptions of List and Padded representations.
"""
@ -354,8 +372,8 @@ class Meshes(object):
self.equisized = True
elif torch.is_tensor(verts) and torch.is_tensor(faces):
if verts.size(2) != 3 and faces.size(2) != 3:
raise ValueError("Verts and Faces tensors have incorrect dimensions.")
if verts.size(2) != 3 or faces.size(2) != 3:
raise ValueError("Verts or Faces tensors have incorrect dimensions.")
self._verts_padded = verts
self._faces_padded = faces.to(torch.int64)
self._N = self._verts_padded.shape[0]
@ -412,6 +430,36 @@ class Meshes(object):
self.textures._N = self._N
self.textures.valid = self.valid
if verts_normals is not None:
self._set_verts_normals(verts_normals)
def _set_verts_normals(self, verts_normals) -> None:
if isinstance(verts_normals, list):
if len(verts_normals) != self._N:
raise ValueError("Invalid verts_normals input")
for item, n_verts in zip(verts_normals, self._num_verts_per_mesh):
if (
not isinstance(item, torch.Tensor)
or item.ndim != 2
or item.shape[1] != 3
or item.shape[0] != n_verts
):
raise ValueError("Invalid verts_normals input")
self._verts_normals_packed = torch.cat(verts_normals, 0)
elif torch.is_tensor(verts_normals):
if (
verts_normals.ndim != 3
or verts_normals.size(2) != 3
or verts_normals.size(0) != self._N
):
raise ValueError("Vertex normals tensor has incorrect dimensions.")
self._verts_normals_packed = struct_utils.padded_to_packed(
verts_normals, split_size=self._num_verts_per_mesh.tolist()
)
else:
raise ValueError("verts_normals must be a list or tensor")
def __len__(self):
return self._N
@ -1253,6 +1301,7 @@ class Meshes(object):
def offset_verts_(self, vert_offsets_packed):
"""
Add an offset to the vertices of this Meshes. In place operation.
If normals are present they may be recalculated.
Args:
vert_offsets_packed: A Tensor of shape (3,) or the same shape as
@ -1286,7 +1335,7 @@ class Meshes(object):
self._verts_padded[i, : verts.shape[0], :] = verts
# update face areas and normals and vertex normals
# only if the original attributes are computed
# only if the original attributes are present
if update_normals and any(
v is not None
for v in [self._faces_areas_packed, self._faces_normals_packed]

View File

@ -223,7 +223,7 @@ class TestMeshNormalConsistency(unittest.TestCase):
Test Mesh Normal Consistency for a mesh known to have no
intersecting faces.
"""
verts = torch.rand(1, 6, 2)
verts = torch.rand(1, 6, 3)
faces = torch.arange(6).reshape(1, 2, 3)
meshes = Meshes(verts=verts, faces=faces)
out = mesh_normal_consistency(meshes)

View File

@ -1138,6 +1138,20 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
self.assertEqual(meshes.faces_normals_padded().shape[0], 0)
self.assertEqual(meshes.faces_normals_list(), [])
def test_assigned_normals(self):
verts = torch.rand(2, 6, 3)
faces = torch.randint(6, size=(2, 4, 3))
for verts_normals in [list(verts.unbind(0)), verts]:
yes_normals = Meshes(
verts=verts.clone(), faces=faces, verts_normals=verts_normals
)
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]))
self.assertClose(yes_normals.verts_normals_padded(), verts)
yes_normals.offset_verts_(torch.FloatTensor([1, 2, 3]).expand(12, 3))
self.assertFalse(torch.allclose(yes_normals.verts_normals_padded(), verts))
def test_compute_faces_areas_cpu_cuda(self):
num_meshes = 10
max_v = 100