mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	fix subdivide_meshes with empty mesh #1788
Summary: Simplify code fixes https://github.com/facebookresearch/pytorch3d/issues/1788 Reviewed By: MichaelRamamonjisoa Differential Revision: D61847675 fbshipit-source-id: 48400875d1d885bb3615bc9f4b3c7c3d822b67e7
This commit is contained in:
		
							parent
							
								
									c434957b2a
								
							
						
					
					
						commit
						8fe6934885
					
				@ -353,45 +353,16 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
 | 
				
			|||||||
    # e.g. verts_per_mesh = (4, 5, 6)
 | 
					    # e.g. verts_per_mesh = (4, 5, 6)
 | 
				
			||||||
    # e.g. edges_per_mesh = (5, 7, 9)
 | 
					    # e.g. edges_per_mesh = (5, 7, 9)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    V = verts_per_mesh.sum()  # e.g. 15
 | 
					    rng = torch.arange(verts_per_mesh.shape[0], device=device)  # (0,1,2)
 | 
				
			||||||
    E = edges_per_mesh.sum()  # e.g. 21
 | 
					    verts_nums = rng.repeat_interleave(
 | 
				
			||||||
 | 
					        verts_per_mesh
 | 
				
			||||||
    verts_per_mesh_cumsum = verts_per_mesh.cumsum(dim=0)  # (N,) e.g. (4, 9, 15)
 | 
					    )  # (0,0,0,0,1,1,1,1,1,2,2,2,2,2,2)
 | 
				
			||||||
    edges_per_mesh_cumsum = edges_per_mesh.cumsum(dim=0)  # (N,) e.g. (5, 12, 21)
 | 
					    edges_nums = rng.repeat_interleave(
 | 
				
			||||||
 | 
					        edges_per_mesh
 | 
				
			||||||
    v_to_e_idx = verts_per_mesh_cumsum.clone()
 | 
					    )  # (0,0,0,0,0,1,1,1,1,1,1,1,2,2,2,2,2,2,2,2,2)
 | 
				
			||||||
 | 
					    nums = torch.cat([verts_nums, edges_nums])
 | 
				
			||||||
    # vertex to edge index.
 | 
					 | 
				
			||||||
    v_to_e_idx[1:] += edges_per_mesh_cumsum[
 | 
					 | 
				
			||||||
        :-1
 | 
					 | 
				
			||||||
    ]  # e.g. (4, 9, 15) + (0, 5, 12) = (4, 14, 27)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # vertex to edge offset.
 | 
					 | 
				
			||||||
    v_to_e_offset = V - verts_per_mesh_cumsum  # e.g. 15 - (4, 9, 15) = (11, 6, 0)
 | 
					 | 
				
			||||||
    v_to_e_offset[1:] += edges_per_mesh_cumsum[
 | 
					 | 
				
			||||||
        :-1
 | 
					 | 
				
			||||||
    ]  # e.g. (11, 6, 0) + (0, 5, 12) = (11, 11, 12)
 | 
					 | 
				
			||||||
    e_to_v_idx = (
 | 
					 | 
				
			||||||
        verts_per_mesh_cumsum[:-1] + edges_per_mesh_cumsum[:-1]
 | 
					 | 
				
			||||||
    )  # (4, 9) + (5, 12) = (9, 21)
 | 
					 | 
				
			||||||
    e_to_v_offset = (
 | 
					 | 
				
			||||||
        verts_per_mesh_cumsum[:-1] - edges_per_mesh_cumsum[:-1] - V
 | 
					 | 
				
			||||||
    )  # (4, 9) - (5, 12) - 15 = (-16, -18)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # Add one new vertex per edge.
 | 
					 | 
				
			||||||
    idx_diffs = torch.ones(V + E, device=device, dtype=torch.int64)  # (36,)
 | 
					 | 
				
			||||||
    idx_diffs[v_to_e_idx] += v_to_e_offset
 | 
					 | 
				
			||||||
    idx_diffs[e_to_v_idx] += e_to_v_offset
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # e.g.
 | 
					 | 
				
			||||||
    # [
 | 
					 | 
				
			||||||
    #  1, 1, 1, 1, 12, 1, 1, 1, 1,
 | 
					 | 
				
			||||||
    #  -15, 1, 1, 1, 1, 12, 1, 1, 1, 1, 1, 1,
 | 
					 | 
				
			||||||
    #  -17, 1, 1, 1, 1, 1, 13, 1, 1, 1, 1, 1, 1, 1
 | 
					 | 
				
			||||||
    # ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    verts_idx = idx_diffs.cumsum(dim=0) - 1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    verts_idx = torch.argsort(nums, stable=True)
 | 
				
			||||||
    # e.g.
 | 
					    # e.g.
 | 
				
			||||||
    # [
 | 
					    # [
 | 
				
			||||||
    #  0, 1, 2, 3, 15, 16, 17, 18, 19,                            --> mesh 0
 | 
					    #  0, 1, 2, 3, 15, 16, 17, 18, 19,                            --> mesh 0
 | 
				
			||||||
@ -400,7 +371,6 @@ def _create_verts_index(verts_per_mesh, edges_per_mesh, device=None):
 | 
				
			|||||||
    # ]
 | 
					    # ]
 | 
				
			||||||
    # where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
 | 
					    # where for mesh 0, [0, 1, 2, 3] are the indices of the existing verts, and
 | 
				
			||||||
    # [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
 | 
					    # [15, 16, 17, 18, 19] are the indices of the new verts after subdivision.
 | 
				
			||||||
 | 
					 | 
				
			||||||
    return verts_idx
 | 
					    return verts_idx
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -421,44 +391,9 @@ def _create_faces_index(faces_per_mesh: torch.Tensor, device=None):
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
    # e.g. faces_per_mesh = [2, 5, 3]
 | 
					    # e.g. faces_per_mesh = [2, 5, 3]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    F = faces_per_mesh.sum()  # e.g. 10
 | 
					    rng = torch.arange(faces_per_mesh.shape[0], device=device)  # (0,1,2)
 | 
				
			||||||
    faces_per_mesh_cumsum = faces_per_mesh.cumsum(dim=0)  # (N,) e.g. (2, 7, 10)
 | 
					    nums = rng.repeat_interleave(faces_per_mesh).repeat(4)
 | 
				
			||||||
 | 
					    faces_idx = torch.argsort(nums, stable=True)
 | 
				
			||||||
    switch1_idx = faces_per_mesh_cumsum.clone()
 | 
					 | 
				
			||||||
    switch1_idx[1:] += (
 | 
					 | 
				
			||||||
        3 * faces_per_mesh_cumsum[:-1]
 | 
					 | 
				
			||||||
    )  # e.g. (2, 7, 10) + (0, 6, 21) = (2, 13, 31)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    switch2_idx = 2 * faces_per_mesh_cumsum  # e.g. (4, 14, 20)
 | 
					 | 
				
			||||||
    switch2_idx[1:] += (
 | 
					 | 
				
			||||||
        2 * faces_per_mesh_cumsum[:-1]
 | 
					 | 
				
			||||||
    )  # e.g. (4, 14, 20) + (0, 4, 14) = (4, 18, 34)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    switch3_idx = 3 * faces_per_mesh_cumsum  # e.g. (6, 21, 30)
 | 
					 | 
				
			||||||
    switch3_idx[1:] += faces_per_mesh_cumsum[
 | 
					 | 
				
			||||||
        :-1
 | 
					 | 
				
			||||||
    ]  # e.g. (6, 21, 30) + (0, 2, 7) = (6, 23, 37)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    switch4_idx = 4 * faces_per_mesh_cumsum[:-1]  # e.g. (8, 28)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    switch123_offset = F - faces_per_mesh  # e.g. (8, 5, 7)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # pyre-fixme[6]: For 1st param expected `Union[List[int], Size,
 | 
					 | 
				
			||||||
    #  typing.Tuple[int, ...]]` but got `Tensor`.
 | 
					 | 
				
			||||||
    idx_diffs = torch.ones(4 * F, device=device, dtype=torch.int64)
 | 
					 | 
				
			||||||
    idx_diffs[switch1_idx] += switch123_offset
 | 
					 | 
				
			||||||
    idx_diffs[switch2_idx] += switch123_offset
 | 
					 | 
				
			||||||
    idx_diffs[switch3_idx] += switch123_offset
 | 
					 | 
				
			||||||
    idx_diffs[switch4_idx] -= 3 * F
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    # e.g
 | 
					 | 
				
			||||||
    # [
 | 
					 | 
				
			||||||
    #  1, 1, 9, 1, 9, 1, 9, 1,                                       -> mesh 0
 | 
					 | 
				
			||||||
    #  -29, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, 6, 1, 1, 1, 1, -> mesh 1
 | 
					 | 
				
			||||||
    #  -29, 1, 1, 8, 1, 1, 8, 1, 1, 8, 1, 1                          -> mesh 2
 | 
					 | 
				
			||||||
    # ]
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    faces_idx = idx_diffs.cumsum(dim=0) - 1
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # e.g.
 | 
					    # e.g.
 | 
				
			||||||
    # [
 | 
					    # [
 | 
				
			||||||
 | 
				
			|||||||
@ -217,6 +217,15 @@ class TestSubdivideMeshes(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
        self.assertClose(new_feats, gt_feats)
 | 
					        self.assertClose(new_feats, gt_feats)
 | 
				
			||||||
        self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
 | 
					        self.assertTrue(new_feats.requires_grad == gt_feats.requires_grad)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_with_empty(self):
 | 
				
			||||||
 | 
					        verts_list = [[[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], []]
 | 
				
			||||||
 | 
					        faces_list = [[[0, 1, 2], [0, 2, 3]], []]
 | 
				
			||||||
 | 
					        verts_list = [torch.tensor(verts, dtype=torch.float64) for verts in verts_list]
 | 
				
			||||||
 | 
					        face_list = [torch.tensor(faces, dtype=torch.long) for faces in faces_list]
 | 
				
			||||||
 | 
					        meshes = Meshes(verts=verts_list, faces=face_list)
 | 
				
			||||||
 | 
					        subdivided_meshes = SubdivideMeshes()(meshes)
 | 
				
			||||||
 | 
					        self.assertEqual(len(subdivided_meshes), 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
 | 
					    def subdivide_meshes_with_init(num_meshes: int = 10, same_topo: bool = False):
 | 
				
			||||||
        device = torch.device("cuda:0")
 | 
					        device = torch.device("cuda:0")
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user