fix subdivide_meshes with empty mesh #1788

Summary:
Simplify code

fixes https://github.com/facebookresearch/pytorch3d/issues/1788

Reviewed By: MichaelRamamonjisoa

Differential Revision: D61847675

fbshipit-source-id: 48400875d1d885bb3615bc9f4b3c7c3d822b67e7
This commit is contained in:
Jeremy Reizenstein 2024-11-06 11:40:26 -08:00 committed by Facebook GitHub Bot
parent c434957b2a
commit 8fe6934885
2 changed files with 21 additions and 77 deletions

View File

@ -353,45 +353,16 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
# e.g. verts_per_mesh = (4, 5, 6)
# e.g. edges_per_mesh = (5, 7, 9)
V = verts_per_mesh.sum() # e.g. 15
E = edges_per_mesh.sum() # e.g. 21
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
v_to_e_idx = verts_per_mesh_cumsum.clone()
# vertex to edge index.
v_to_e_idx[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
# vertex to edge offset.
v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
v_to_e_offset[1:] += edges_per_mesh_cumsum[
:-1
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
e_to_v_idx = (
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
) # (4, 9) + (5, 12) = (9, 21)
e_to_v_offset = (
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
) # (4, 9) - (5, 12) - 15 = (-16, -18)
# Add one new vertex per edge.
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
idx_diffs[v_to_e_idx] += v_to_e_offset
idx_diffs[e_to_v_idx] += e_to_v_offset
# e.g.
# [
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
# ]
verts_idx = idx_diffs.cumsum(dim=0) - 1
rng = torch.arange(verts_per_mesh.shape[0], device=device) # (0,1,2)
verts_nums = rng.repeat_interleave(
verts_per_mesh
) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
edges_nums = rng.repeat_interleave(
edges_per_mesh
) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
nums = torch.cat([verts_nums, edges_nums])
verts_idx = torch.argsort(nums, stable=True)
# e.g.
# [
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
@ -400,7 +371,6 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
# ]
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
return verts_idx
@ -421,44 +391,9 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
"""
# e.g. faces_per_mesh = [2, 5, 3]
F = faces_per_mesh.sum() # e.g. 10
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
switch1_idx = faces_per_mesh_cumsum.clone()
switch1_idx[1:] += (
3 * faces_per_mesh_cumsum[:-1]
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
switch2_idx[1:] += (
2 * faces_per_mesh_cumsum[:-1]
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
switch3_idx[1:] += faces_per_mesh_cumsum[
:-1
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
# typing.Tuple[int, ...]]` but got `Tensor`.
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
idx_diffs[switch1_idx] += switch123_offset
idx_diffs[switch2_idx] += switch123_offset
idx_diffs[switch3_idx] += switch123_offset
idx_diffs[switch4_idx] -= 3 * F
# e.g
# [
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
# ]
faces_idx = idx_diffs.cumsum(dim=0) - 1
rng = torch.arange(faces_per_mesh.shape[0], device=device) # (0,1,2)
nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
faces_idx = torch.argsort(nums, stable=True)
# e.g.
# [

View File

@ -217,6 +217,15 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
self.assertClose(new_feats, gt_feats)
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
def test_with_empty(self):
verts_list = [[[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], []]
faces_list = [[[0, 1, 2], [0, 2, 3]], []]
verts_list = [torch.tensor(verts, dtype=torch.float64) for verts in verts_list]
face_list = [torch.tensor(faces, dtype=torch.long) for faces in faces_list]
meshes = Meshes(verts=verts_list, faces=face_list)
subdivided_meshes = SubdivideMeshes()(meshes)
self.assertEqual(len(subdivided_meshes), 2)
@staticmethod
def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
device = torch.device("cuda:0")