update padded in meshes

Summary:
Three changes to Meshes

1. `num_verts_per_mesh` and `num_faces_per_mesh` are assigned at construction time and are returned without the need for `compute_packed`
2. `update_padded` updates `verts_padded` and shallow copies faces list and faces_padded and existing attributes from construction.
3. `padded_to_packed_idx` does not need `compute_packed`

Reviewed By: nikhilaravi

Differential Revision: D21653674

fbshipit-source-id: dc6815a2e2a925fe4a834fe357919da2b2c14527
This commit is contained in:
Georgia Gkioxari 2020-05-22 22:36:47 -07:00 committed by Facebook GitHub Bot
parent ae68a54f67
commit 1fb97f9c84
2 changed files with 201 additions and 21 deletions

View File

@ -462,9 +462,9 @@ class Meshes(object):
assert (
self._verts_padded is not None
), "verts_padded is required to compute verts_list."
self._verts_list = [
v[0] for v in self._verts_padded.split([1] * self._N, 0)
]
self._verts_list = struct_utils.padded_to_list(
self._verts_padded, self.num_verts_per_mesh().tolist()
)
return self._verts_list
def faces_list(self):
@ -478,10 +478,9 @@ class Meshes(object):
assert (
self._faces_padded is not None
), "faces_padded is required to compute faces_list."
self._faces_list = []
for i in range(self._N):
valid = self._faces_padded[i].gt(-1).all(1)
self._faces_list.append(self._faces_padded[i, valid, :])
self._faces_list = struct_utils.padded_to_list(
self._faces_padded, self.num_faces_per_mesh().tolist()
)
return self._faces_list
def verts_packed(self):
@ -525,7 +524,6 @@ class Meshes(object):
Returns:
1D tensor of sizes.
"""
self._compute_packed()
return self._num_verts_per_mesh
def faces_packed(self):
@ -590,7 +588,6 @@ class Meshes(object):
Returns:
1D tensor of sizes.
"""
self._compute_packed()
return self._num_faces_per_mesh
def edges_packed(self):
@ -664,14 +661,13 @@ class Meshes(object):
Returns:
1D tensor of indices.
"""
self._compute_packed()
if self._verts_padded_to_packed_idx is not None:
return self._verts_padded_to_packed_idx
self._verts_padded_to_packed_idx = torch.cat(
[
torch.arange(v, dtype=torch.int64, device=self.device) + i * self._V
for (i, v) in enumerate(self._num_verts_per_mesh)
for (i, v) in enumerate(self.num_verts_per_mesh())
],
dim=0,
)
@ -862,8 +858,8 @@ class Meshes(object):
):
return
verts_list = self._verts_list
faces_list = self._faces_list
verts_list = self.verts_list()
faces_list = self.faces_list()
assert (
faces_list is not None and verts_list is not None
), "faces_list and verts_list arguments are required"
@ -943,13 +939,15 @@ class Meshes(object):
verts_list_to_packed = struct_utils.list_to_packed(verts_list)
self._verts_packed = verts_list_to_packed[0]
self._num_verts_per_mesh = verts_list_to_packed[1]
if not torch.allclose(self.num_verts_per_mesh(), verts_list_to_packed[1]):
raise ValueError("The number of verts per mesh should be consistent.")
self._mesh_to_verts_packed_first_idx = verts_list_to_packed[2]
self._verts_packed_to_mesh_idx = verts_list_to_packed[3]
faces_list_to_packed = struct_utils.list_to_packed(faces_list)
faces_packed = faces_list_to_packed[0]
self._num_faces_per_mesh = faces_list_to_packed[1]
if not torch.allclose(self.num_faces_per_mesh(), faces_list_to_packed[1]):
raise ValueError("The number of faces per mesh should be consistent.")
self._mesh_to_faces_packed_first_idx = faces_list_to_packed[2]
self._faces_packed_to_mesh_idx = faces_list_to_packed[3]
@ -1328,6 +1326,100 @@ class Meshes(object):
new_mesh = self.clone()
return new_mesh.scale_verts_(scale)
def update_padded(self, new_verts_padded):
"""
This function allows for an pdate of verts_padded without having to
explicitly convert it to the list representation for heterogeneous batches.
Returns a Meshes structure with updated padded tensors and copies of the
auxiliary tensors at construction time.
It updates self._verts_padded with new_verts_padded, and does a
shallow copy of (faces_padded, faces_list, num_verts_per_mesh, num_faces_per_mesh).
If packed representations are computed in self, they are updated as well.
Args:
new_points_padded: FloatTensor of shape (N, V, 3)
Returns:
Meshes with updated padded representations
"""
def check_shapes(x, size):
if x.shape[0] != size[0]:
raise ValueError("new values must have the same batch dimension.")
if x.shape[1] != size[1]:
raise ValueError("new values must have the same number of points.")
if x.shape[2] != size[2]:
raise ValueError("new values must have the same dimension.")
check_shapes(new_verts_padded, [self._N, self._V, 3])
new = self.__class__(verts=new_verts_padded, faces=self.faces_padded())
if new._N != self._N or new._V != self._V or new._F != self._F:
raise ValueError("Inconsistent sizes after construction.")
# overwrite the equisized flag
new.equisized = self.equisized
# overwrite textures if any
new.textures = self.textures
# copy auxiliary tensors
copy_tensors = ["_num_verts_per_mesh", "_num_faces_per_mesh", "valid"]
for k in copy_tensors:
v = getattr(self, k)
if torch.is_tensor(v):
setattr(new, k, v) # shallow copy
# shallow copy of faces_list if any, st new.faces_list()
# does not re-compute from _faces_padded
new._faces_list = self._faces_list
# update verts/faces packed if they are computed in self
if self._verts_packed is not None:
copy_tensors = [
"_faces_packed",
"_verts_packed_to_mesh_idx",
"_faces_packed_to_mesh_idx",
"_mesh_to_verts_packed_first_idx",
"_mesh_to_faces_packed_first_idx",
]
for k in copy_tensors:
v = getattr(self, k)
assert torch.is_tensor(v)
setattr(new, k, v) # shallow copy
# update verts_packed
pad_to_packed = self.verts_padded_to_packed_idx()
new_verts_packed = new_verts_padded.reshape(-1, 3)[pad_to_packed, :]
new._verts_packed = new_verts_packed
new._verts_padded_to_packed_idx = pad_to_packed
# update edges packed if they are computed in self
if self._edges_packed is not None:
copy_tensors = [
"_edges_packed",
"_edges_packed_to_mesh_idx",
"_mesh_to_edges_packed_first_idx",
"_faces_packed_to_edges_packed",
"_num_edges_per_mesh",
]
for k in copy_tensors:
v = getattr(self, k)
assert torch.is_tensor(v)
setattr(new, k, v) # shallow copy
# update laplacian if it is compute in self
if self._laplacian_packed is not None:
new._laplacian_packed = self._laplacian_packed
assert new._verts_list is None
assert new._verts_normals_packed is None
assert new._faces_normals_packed is None
assert new._faces_areas_packed is None
return new
# TODO(nikhilar) Move function to utils file.
def get_bounding_boxes(self):
"""

View File

@ -342,16 +342,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
# Modify tensors in both meshes.
new_mesh._verts_list[0] = new_mesh._verts_list[0] * 5
mesh._num_verts_per_mesh = torch.randint_like(
mesh.num_verts_per_mesh(), high=10
)
# Check cloned and original Meshes objects do not share tensors.
self.assertFalse(
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])
)
self.assertFalse(
torch.allclose(mesh.num_verts_per_mesh(), new_mesh.num_verts_per_mesh())
)
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
self.assertSeparate(new_mesh.verts_padded(), mesh.verts_padded())
self.assertSeparate(new_mesh.faces_packed(), mesh.faces_packed())
@ -690,6 +685,99 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
with self.assertRaises(ValueError):
mesh.split(split_sizes)
def test_update_padded(self):
# Define the test mesh object either as a list or tensor of faces/verts.
N = 10
for lists_to_tensors in (False, True):
for force in (True, False):
mesh = TestMeshes.init_mesh(
N, 100, 300, lists_to_tensors=lists_to_tensors
)
num_verts_per_mesh = mesh.num_verts_per_mesh()
if force:
# force mesh to have computed attributes
mesh.verts_packed()
mesh.edges_packed()
mesh.laplacian_packed()
mesh.faces_areas_packed()
new_verts = torch.rand((mesh._N, mesh._V, 3), device=mesh.device)
new_verts_list = [
new_verts[i, : num_verts_per_mesh[i]] for i in range(N)
]
new_mesh = mesh.update_padded(new_verts)
# check the attributes assigned at construction time
self.assertEqual(new_mesh._N, mesh._N)
self.assertEqual(new_mesh._F, mesh._F)
self.assertEqual(new_mesh._V, mesh._V)
self.assertEqual(new_mesh.equisized, mesh.equisized)
self.assertTrue(all(new_mesh.valid == mesh.valid))
self.assertNotSeparate(
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
)
self.assertClose(
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
)
self.assertNotSeparate(
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
)
self.assertClose(
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
)
# check that the following attributes are not assigned
self.assertIsNone(new_mesh._verts_list)
self.assertIsNone(new_mesh._faces_areas_packed)
self.assertIsNone(new_mesh._faces_normals_packed)
self.assertIsNone(new_mesh._verts_normals_packed)
check_tensors = [
"_faces_packed",
"_verts_packed_to_mesh_idx",
"_faces_packed_to_mesh_idx",
"_mesh_to_verts_packed_first_idx",
"_mesh_to_faces_packed_first_idx",
"_edges_packed",
"_edges_packed_to_mesh_idx",
"_mesh_to_edges_packed_first_idx",
"_faces_packed_to_edges_packed",
"_num_edges_per_mesh",
]
for k in check_tensors:
v = getattr(new_mesh, k)
if not force:
self.assertIsNone(v)
else:
v_old = getattr(mesh, k)
self.assertNotSeparate(v, v_old)
self.assertClose(v, v_old)
# check verts/faces padded
self.assertClose(new_mesh.verts_padded(), new_verts)
self.assertNotSeparate(new_mesh.verts_padded(), new_verts)
self.assertClose(new_mesh.faces_padded(), mesh.faces_padded())
self.assertNotSeparate(new_mesh.faces_padded(), mesh.faces_padded())
# check verts/faces list
for i in range(N):
self.assertNotSeparate(
new_mesh.faces_list()[i], mesh.faces_list()[i]
)
self.assertClose(new_mesh.faces_list()[i], mesh.faces_list()[i])
self.assertSeparate(new_mesh.verts_list()[i], mesh.verts_list()[i])
self.assertClose(new_mesh.verts_list()[i], new_verts_list[i])
# check verts/faces packed
self.assertClose(new_mesh.verts_packed(), torch.cat(new_verts_list))
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
self.assertClose(new_mesh.faces_packed(), mesh.faces_packed())
# check pad_to_packed
self.assertClose(
new_mesh.verts_padded_to_packed_idx(),
mesh.verts_padded_to_packed_idx(),
)
# check edges
self.assertClose(new_mesh.edges_packed(), mesh.edges_packed())
def test_get_mesh_verts_faces(self):
device = torch.device("cuda:0")
verts_list = []