mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
fix subdivide_meshes with empty mesh #1788
Summary: Simplify code fixes https://github.com/facebookresearch/pytorch3d/issues/1788 Reviewed By: MichaelRamamonjisoa Differential Revision: D61847675 fbshipit-source-id: 48400875d1d885bb3615bc9f4b3c7c3d822b67e7
This commit is contained in:
parent
c434957b2a
commit
8fe6934885
@ -353,45 +353,16 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
|
|||||||
# e.g. verts_per_mesh = (4, 5, 6)
|
# e.g. verts_per_mesh = (4, 5, 6)
|
||||||
# e.g. edges_per_mesh = (5, 7, 9)
|
# e.g. edges_per_mesh = (5, 7, 9)
|
||||||
|
|
||||||
V = verts_per_mesh.sum() # e.g. 15
|
rng = torch.arange(verts_per_mesh.shape[0], device=device) # (0,1,2)
|
||||||
E = edges_per_mesh.sum() # e.g. 21
|
verts_nums = rng.repeat_interleave(
|
||||||
|
verts_per_mesh
|
||||||
verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0) # (N,) e.g. (4, 9, 15)
|
) # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
|
||||||
edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0) # (N,) e.g. (5, 12, 21)
|
edges_nums = rng.repeat_interleave(
|
||||||
|
edges_per_mesh
|
||||||
v_to_e_idx = verts_per_mesh_cumsum.clone()
|
) # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
|
||||||
|
nums = torch.cat([verts_nums, edges_nums])
|
||||||
# vertex to edge index.
|
|
||||||
v_to_e_idx[1:] += edges_per_mesh_cumsum[
|
|
||||||
:-1
|
|
||||||
] # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
|
|
||||||
|
|
||||||
# vertex to edge offset.
|
|
||||||
v_to_e_offset = V - verts_per_mesh_cumsum # e.g. 15 - (4, 9, 15) = (11, 6, 0)
|
|
||||||
v_to_e_offset[1:] += edges_per_mesh_cumsum[
|
|
||||||
:-1
|
|
||||||
] # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
|
|
||||||
e_to_v_idx = (
|
|
||||||
verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
|
|
||||||
) # (4, 9) + (5, 12) = (9, 21)
|
|
||||||
e_to_v_offset = (
|
|
||||||
verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
|
|
||||||
) # (4, 9) - (5, 12) - 15 = (-16, -18)
|
|
||||||
|
|
||||||
# Add one new vertex per edge.
|
|
||||||
idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64) # (36,)
|
|
||||||
idx_diffs[v_to_e_idx] += v_to_e_offset
|
|
||||||
idx_diffs[e_to_v_idx] += e_to_v_offset
|
|
||||||
|
|
||||||
# e.g.
|
|
||||||
# [
|
|
||||||
# 1, 1, 1, 1, 12, 1, 1, 1, 1,
|
|
||||||
# -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
|
|
||||||
# -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
|
|
||||||
# ]
|
|
||||||
|
|
||||||
verts_idx = idx_diffs.cumsum(dim=0) - 1
|
|
||||||
|
|
||||||
|
verts_idx = torch.argsort(nums, stable=True)
|
||||||
# e.g.
|
# e.g.
|
||||||
# [
|
# [
|
||||||
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
|
# 0, 1, 2, 3, 15, 16, 17, 18, 19, --> mesh 0
|
||||||
@ -400,7 +371,6 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
|
|||||||
# ]
|
# ]
|
||||||
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
|
# where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
|
||||||
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
|
# [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
|
||||||
|
|
||||||
return verts_idx
|
return verts_idx
|
||||||
|
|
||||||
|
|
||||||
@ -421,44 +391,9 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
|
|||||||
"""
|
"""
|
||||||
# e.g. faces_per_mesh = [2, 5, 3]
|
# e.g. faces_per_mesh = [2, 5, 3]
|
||||||
|
|
||||||
F = faces_per_mesh.sum() # e.g. 10
|
rng = torch.arange(faces_per_mesh.shape[0], device=device) # (0,1,2)
|
||||||
faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0) # (N,) e.g. (2, 7, 10)
|
nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
|
||||||
|
faces_idx = torch.argsort(nums, stable=True)
|
||||||
switch1_idx = faces_per_mesh_cumsum.clone()
|
|
||||||
switch1_idx[1:] += (
|
|
||||||
3 * faces_per_mesh_cumsum[:-1]
|
|
||||||
) # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
|
|
||||||
|
|
||||||
switch2_idx = 2 * faces_per_mesh_cumsum # e.g. (4, 14, 20)
|
|
||||||
switch2_idx[1:] += (
|
|
||||||
2 * faces_per_mesh_cumsum[:-1]
|
|
||||||
) # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
|
|
||||||
|
|
||||||
switch3_idx = 3 * faces_per_mesh_cumsum # e.g. (6, 21, 30)
|
|
||||||
switch3_idx[1:] += faces_per_mesh_cumsum[
|
|
||||||
:-1
|
|
||||||
] # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
|
|
||||||
|
|
||||||
switch4_idx = 4 * faces_per_mesh_cumsum[:-1] # e.g. (8, 28)
|
|
||||||
|
|
||||||
switch123_offset = F - faces_per_mesh # e.g. (8, 5, 7)
|
|
||||||
|
|
||||||
# pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
|
|
||||||
# typing.Tuple[int, ...]]` but got `Tensor`.
|
|
||||||
idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
|
|
||||||
idx_diffs[switch1_idx] += switch123_offset
|
|
||||||
idx_diffs[switch2_idx] += switch123_offset
|
|
||||||
idx_diffs[switch3_idx] += switch123_offset
|
|
||||||
idx_diffs[switch4_idx] -= 3 * F
|
|
||||||
|
|
||||||
# e.g
|
|
||||||
# [
|
|
||||||
# 1, 1, 9, 1, 9, 1, 9, 1, -> mesh 0
|
|
||||||
# -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
|
|
||||||
# -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1 -> mesh 2
|
|
||||||
# ]
|
|
||||||
|
|
||||||
faces_idx = idx_diffs.cumsum(dim=0) - 1
|
|
||||||
|
|
||||||
# e.g.
|
# e.g.
|
||||||
# [
|
# [
|
||||||
|
@ -217,6 +217,15 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
self.assertClose(new_feats, gt_feats)
|
self.assertClose(new_feats, gt_feats)
|
||||||
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
|
self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
|
||||||
|
|
||||||
|
def test_with_empty(self):
|
||||||
|
verts_list = [[[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], []]
|
||||||
|
faces_list = [[[0, 1, 2], [0, 2, 3]], []]
|
||||||
|
verts_list = [torch.tensor(verts, dtype=torch.float64) for verts in verts_list]
|
||||||
|
face_list = [torch.tensor(faces, dtype=torch.long) for faces in faces_list]
|
||||||
|
meshes = Meshes(verts=verts_list, faces=face_list)
|
||||||
|
subdivided_meshes = SubdivideMeshes()(meshes)
|
||||||
|
self.assertEqual(len(subdivided_meshes), 2)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
|
def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
|
||||||
device = torch.device("cuda:0")
|
device = torch.device("cuda:0")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user