mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
update padded in meshes
Summary: Three changes to Meshes 1. `num_verts_per_mesh` and `num_faces_per_mesh` are assigned at construction time and are returned without the need for `compute_packed` 2. `update_padded` updates `verts_padded` and shallow copies faces list and faces_padded and existing attributes from construction. 3. `padded_to_packed_idx` does not need `compute_packed` Reviewed By: nikhilaravi Differential Revision: D21653674 fbshipit-source-id: dc6815a2e2a925fe4a834fe357919da2b2c14527
This commit is contained in:
parent
ae68a54f67
commit
1fb97f9c84
@ -462,9 +462,9 @@ class Meshes(object):
|
||||
assert (
|
||||
self._verts_padded is not None
|
||||
), "verts_padded is required to compute verts_list."
|
||||
self._verts_list = [
|
||||
v[0] for v in self._verts_padded.split([1] * self._N, 0)
|
||||
]
|
||||
self._verts_list = struct_utils.padded_to_list(
|
||||
self._verts_padded, self.num_verts_per_mesh().tolist()
|
||||
)
|
||||
return self._verts_list
|
||||
|
||||
def faces_list(self):
|
||||
@ -478,10 +478,9 @@ class Meshes(object):
|
||||
assert (
|
||||
self._faces_padded is not None
|
||||
), "faces_padded is required to compute faces_list."
|
||||
self._faces_list = []
|
||||
for i in range(self._N):
|
||||
valid = self._faces_padded[i].gt(-1).all(1)
|
||||
self._faces_list.append(self._faces_padded[i, valid, :])
|
||||
self._faces_list = struct_utils.padded_to_list(
|
||||
self._faces_padded, self.num_faces_per_mesh().tolist()
|
||||
)
|
||||
return self._faces_list
|
||||
|
||||
def verts_packed(self):
|
||||
@ -525,7 +524,6 @@ class Meshes(object):
|
||||
Returns:
|
||||
1D tensor of sizes.
|
||||
"""
|
||||
self._compute_packed()
|
||||
return self._num_verts_per_mesh
|
||||
|
||||
def faces_packed(self):
|
||||
@ -590,7 +588,6 @@ class Meshes(object):
|
||||
Returns:
|
||||
1D tensor of sizes.
|
||||
"""
|
||||
self._compute_packed()
|
||||
return self._num_faces_per_mesh
|
||||
|
||||
def edges_packed(self):
|
||||
@ -664,14 +661,13 @@ class Meshes(object):
|
||||
Returns:
|
||||
1D tensor of indices.
|
||||
"""
|
||||
self._compute_packed()
|
||||
if self._verts_padded_to_packed_idx is not None:
|
||||
return self._verts_padded_to_packed_idx
|
||||
|
||||
self._verts_padded_to_packed_idx = torch.cat(
|
||||
[
|
||||
torch.arange(v, dtype=torch.int64, device=self.device) + i * self._V
|
||||
for (i, v) in enumerate(self._num_verts_per_mesh)
|
||||
for (i, v) in enumerate(self.num_verts_per_mesh())
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
@ -862,8 +858,8 @@ class Meshes(object):
|
||||
):
|
||||
return
|
||||
|
||||
verts_list = self._verts_list
|
||||
faces_list = self._faces_list
|
||||
verts_list = self.verts_list()
|
||||
faces_list = self.faces_list()
|
||||
assert (
|
||||
faces_list is not None and verts_list is not None
|
||||
), "faces_list and verts_list arguments are required"
|
||||
@ -943,13 +939,15 @@ class Meshes(object):
|
||||
|
||||
verts_list_to_packed = struct_utils.list_to_packed(verts_list)
|
||||
self._verts_packed = verts_list_to_packed[0]
|
||||
self._num_verts_per_mesh = verts_list_to_packed[1]
|
||||
if not torch.allclose(self.num_verts_per_mesh(), verts_list_to_packed[1]):
|
||||
raise ValueError("The number of verts per mesh should be consistent.")
|
||||
self._mesh_to_verts_packed_first_idx = verts_list_to_packed[2]
|
||||
self._verts_packed_to_mesh_idx = verts_list_to_packed[3]
|
||||
|
||||
faces_list_to_packed = struct_utils.list_to_packed(faces_list)
|
||||
faces_packed = faces_list_to_packed[0]
|
||||
self._num_faces_per_mesh = faces_list_to_packed[1]
|
||||
if not torch.allclose(self.num_faces_per_mesh(), faces_list_to_packed[1]):
|
||||
raise ValueError("The number of faces per mesh should be consistent.")
|
||||
self._mesh_to_faces_packed_first_idx = faces_list_to_packed[2]
|
||||
self._faces_packed_to_mesh_idx = faces_list_to_packed[3]
|
||||
|
||||
@ -1328,6 +1326,100 @@ class Meshes(object):
|
||||
new_mesh = self.clone()
|
||||
return new_mesh.scale_verts_(scale)
|
||||
|
||||
def update_padded(self, new_verts_padded):
|
||||
"""
|
||||
This function allows for an pdate of verts_padded without having to
|
||||
explicitly convert it to the list representation for heterogeneous batches.
|
||||
Returns a Meshes structure with updated padded tensors and copies of the
|
||||
auxiliary tensors at construction time.
|
||||
It updates self._verts_padded with new_verts_padded, and does a
|
||||
shallow copy of (faces_padded, faces_list, num_verts_per_mesh, num_faces_per_mesh).
|
||||
If packed representations are computed in self, they are updated as well.
|
||||
|
||||
Args:
|
||||
new_points_padded: FloatTensor of shape (N, V, 3)
|
||||
|
||||
Returns:
|
||||
Meshes with updated padded representations
|
||||
"""
|
||||
|
||||
def check_shapes(x, size):
|
||||
if x.shape[0] != size[0]:
|
||||
raise ValueError("new values must have the same batch dimension.")
|
||||
if x.shape[1] != size[1]:
|
||||
raise ValueError("new values must have the same number of points.")
|
||||
if x.shape[2] != size[2]:
|
||||
raise ValueError("new values must have the same dimension.")
|
||||
|
||||
check_shapes(new_verts_padded, [self._N, self._V, 3])
|
||||
|
||||
new = self.__class__(verts=new_verts_padded, faces=self.faces_padded())
|
||||
|
||||
if new._N != self._N or new._V != self._V or new._F != self._F:
|
||||
raise ValueError("Inconsistent sizes after construction.")
|
||||
|
||||
# overwrite the equisized flag
|
||||
new.equisized = self.equisized
|
||||
|
||||
# overwrite textures if any
|
||||
new.textures = self.textures
|
||||
|
||||
# copy auxiliary tensors
|
||||
copy_tensors = ["_num_verts_per_mesh", "_num_faces_per_mesh", "valid"]
|
||||
|
||||
for k in copy_tensors:
|
||||
v = getattr(self, k)
|
||||
if torch.is_tensor(v):
|
||||
setattr(new, k, v) # shallow copy
|
||||
|
||||
# shallow copy of faces_list if any, st new.faces_list()
|
||||
# does not re-compute from _faces_padded
|
||||
new._faces_list = self._faces_list
|
||||
|
||||
# update verts/faces packed if they are computed in self
|
||||
if self._verts_packed is not None:
|
||||
copy_tensors = [
|
||||
"_faces_packed",
|
||||
"_verts_packed_to_mesh_idx",
|
||||
"_faces_packed_to_mesh_idx",
|
||||
"_mesh_to_verts_packed_first_idx",
|
||||
"_mesh_to_faces_packed_first_idx",
|
||||
]
|
||||
for k in copy_tensors:
|
||||
v = getattr(self, k)
|
||||
assert torch.is_tensor(v)
|
||||
setattr(new, k, v) # shallow copy
|
||||
# update verts_packed
|
||||
pad_to_packed = self.verts_padded_to_packed_idx()
|
||||
new_verts_packed = new_verts_padded.reshape(-1, 3)[pad_to_packed, :]
|
||||
new._verts_packed = new_verts_packed
|
||||
new._verts_padded_to_packed_idx = pad_to_packed
|
||||
|
||||
# update edges packed if they are computed in self
|
||||
if self._edges_packed is not None:
|
||||
copy_tensors = [
|
||||
"_edges_packed",
|
||||
"_edges_packed_to_mesh_idx",
|
||||
"_mesh_to_edges_packed_first_idx",
|
||||
"_faces_packed_to_edges_packed",
|
||||
"_num_edges_per_mesh",
|
||||
]
|
||||
for k in copy_tensors:
|
||||
v = getattr(self, k)
|
||||
assert torch.is_tensor(v)
|
||||
setattr(new, k, v) # shallow copy
|
||||
|
||||
# update laplacian if it is compute in self
|
||||
if self._laplacian_packed is not None:
|
||||
new._laplacian_packed = self._laplacian_packed
|
||||
|
||||
assert new._verts_list is None
|
||||
assert new._verts_normals_packed is None
|
||||
assert new._faces_normals_packed is None
|
||||
assert new._faces_areas_packed is None
|
||||
|
||||
return new
|
||||
|
||||
# TODO(nikhilar) Move function to utils file.
|
||||
def get_bounding_boxes(self):
|
||||
"""
|
||||
|
@ -342,16 +342,11 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
|
||||
# Modify tensors in both meshes.
|
||||
new_mesh._verts_list[0] = new_mesh._verts_list[0] * 5
|
||||
mesh._num_verts_per_mesh = torch.randint_like(
|
||||
mesh.num_verts_per_mesh(), high=10
|
||||
)
|
||||
|
||||
# Check cloned and original Meshes objects do not share tensors.
|
||||
self.assertFalse(
|
||||
torch.allclose(new_mesh._verts_list[0], mesh._verts_list[0])
|
||||
)
|
||||
self.assertFalse(
|
||||
torch.allclose(mesh.num_verts_per_mesh(), new_mesh.num_verts_per_mesh())
|
||||
)
|
||||
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
|
||||
self.assertSeparate(new_mesh.verts_padded(), mesh.verts_padded())
|
||||
self.assertSeparate(new_mesh.faces_packed(), mesh.faces_packed())
|
||||
@ -690,6 +685,99 @@ class TestMeshes(TestCaseMixin, unittest.TestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
mesh.split(split_sizes)
|
||||
|
||||
def test_update_padded(self):
|
||||
# Define the test mesh object either as a list or tensor of faces/verts.
|
||||
N = 10
|
||||
for lists_to_tensors in (False, True):
|
||||
for force in (True, False):
|
||||
mesh = TestMeshes.init_mesh(
|
||||
N, 100, 300, lists_to_tensors=lists_to_tensors
|
||||
)
|
||||
num_verts_per_mesh = mesh.num_verts_per_mesh()
|
||||
if force:
|
||||
# force mesh to have computed attributes
|
||||
mesh.verts_packed()
|
||||
mesh.edges_packed()
|
||||
mesh.laplacian_packed()
|
||||
mesh.faces_areas_packed()
|
||||
|
||||
new_verts = torch.rand((mesh._N, mesh._V, 3), device=mesh.device)
|
||||
new_verts_list = [
|
||||
new_verts[i, : num_verts_per_mesh[i]] for i in range(N)
|
||||
]
|
||||
new_mesh = mesh.update_padded(new_verts)
|
||||
|
||||
# check the attributes assigned at construction time
|
||||
self.assertEqual(new_mesh._N, mesh._N)
|
||||
self.assertEqual(new_mesh._F, mesh._F)
|
||||
self.assertEqual(new_mesh._V, mesh._V)
|
||||
self.assertEqual(new_mesh.equisized, mesh.equisized)
|
||||
self.assertTrue(all(new_mesh.valid == mesh.valid))
|
||||
self.assertNotSeparate(
|
||||
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
|
||||
)
|
||||
self.assertClose(
|
||||
new_mesh.num_verts_per_mesh(), mesh.num_verts_per_mesh()
|
||||
)
|
||||
self.assertNotSeparate(
|
||||
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
|
||||
)
|
||||
self.assertClose(
|
||||
new_mesh.num_faces_per_mesh(), mesh.num_faces_per_mesh()
|
||||
)
|
||||
|
||||
# check that the following attributes are not assigned
|
||||
self.assertIsNone(new_mesh._verts_list)
|
||||
self.assertIsNone(new_mesh._faces_areas_packed)
|
||||
self.assertIsNone(new_mesh._faces_normals_packed)
|
||||
self.assertIsNone(new_mesh._verts_normals_packed)
|
||||
|
||||
check_tensors = [
|
||||
"_faces_packed",
|
||||
"_verts_packed_to_mesh_idx",
|
||||
"_faces_packed_to_mesh_idx",
|
||||
"_mesh_to_verts_packed_first_idx",
|
||||
"_mesh_to_faces_packed_first_idx",
|
||||
"_edges_packed",
|
||||
"_edges_packed_to_mesh_idx",
|
||||
"_mesh_to_edges_packed_first_idx",
|
||||
"_faces_packed_to_edges_packed",
|
||||
"_num_edges_per_mesh",
|
||||
]
|
||||
for k in check_tensors:
|
||||
v = getattr(new_mesh, k)
|
||||
if not force:
|
||||
self.assertIsNone(v)
|
||||
else:
|
||||
v_old = getattr(mesh, k)
|
||||
self.assertNotSeparate(v, v_old)
|
||||
self.assertClose(v, v_old)
|
||||
|
||||
# check verts/faces padded
|
||||
self.assertClose(new_mesh.verts_padded(), new_verts)
|
||||
self.assertNotSeparate(new_mesh.verts_padded(), new_verts)
|
||||
self.assertClose(new_mesh.faces_padded(), mesh.faces_padded())
|
||||
self.assertNotSeparate(new_mesh.faces_padded(), mesh.faces_padded())
|
||||
# check verts/faces list
|
||||
for i in range(N):
|
||||
self.assertNotSeparate(
|
||||
new_mesh.faces_list()[i], mesh.faces_list()[i]
|
||||
)
|
||||
self.assertClose(new_mesh.faces_list()[i], mesh.faces_list()[i])
|
||||
self.assertSeparate(new_mesh.verts_list()[i], mesh.verts_list()[i])
|
||||
self.assertClose(new_mesh.verts_list()[i], new_verts_list[i])
|
||||
# check verts/faces packed
|
||||
self.assertClose(new_mesh.verts_packed(), torch.cat(new_verts_list))
|
||||
self.assertSeparate(new_mesh.verts_packed(), mesh.verts_packed())
|
||||
self.assertClose(new_mesh.faces_packed(), mesh.faces_packed())
|
||||
# check pad_to_packed
|
||||
self.assertClose(
|
||||
new_mesh.verts_padded_to_packed_idx(),
|
||||
mesh.verts_padded_to_packed_idx(),
|
||||
)
|
||||
# check edges
|
||||
self.assertClose(new_mesh.edges_packed(), mesh.edges_packed())
|
||||
|
||||
def test_get_mesh_verts_faces(self):
|
||||
device = torch.device("cuda:0")
|
||||
verts_list = []
|
||||
|
Loading…
x
Reference in New Issue
Block a user