mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	validate lengths in chamfer and farthest_points
Summary: Fixes #1326 Reviewed By: kjchalup Differential Revision: D39259697 fbshipit-source-id: 51392f4cc4a956165a62901cb115fcefe0e17277
This commit is contained in:
		
							parent
							
								
									6e25fe8cb3
								
							
						
					
					
						commit
						cb7bd33e7f
					
				@ -48,10 +48,11 @@ def _handle_pointcloud_input(
 | 
			
		||||
        if points.ndim != 3:
 | 
			
		||||
            raise ValueError("Expected points to be of shape (N, P, D)")
 | 
			
		||||
        X = points
 | 
			
		||||
        if lengths is not None and (
 | 
			
		||||
            lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
 | 
			
		||||
        ):
 | 
			
		||||
            raise ValueError("Expected lengths to be of shape (N,)")
 | 
			
		||||
        if lengths is not None:
 | 
			
		||||
            if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]:
 | 
			
		||||
                raise ValueError("Expected lengths to be of shape (N,)")
 | 
			
		||||
            if lengths.max() > X.shape[1]:
 | 
			
		||||
                raise ValueError("A length value was too long")
 | 
			
		||||
        if lengths is None:
 | 
			
		||||
            lengths = torch.full(
 | 
			
		||||
                (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
 | 
			
		||||
 | 
			
		||||
@ -56,9 +56,11 @@ def sample_farthest_points(
 | 
			
		||||
    # Validate inputs
 | 
			
		||||
    if lengths is None:
 | 
			
		||||
        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 | 
			
		||||
 | 
			
		||||
    if lengths.shape != (N,):
 | 
			
		||||
        raise ValueError("points and lengths must have same batch dimension.")
 | 
			
		||||
    else:
 | 
			
		||||
        if lengths.shape != (N,):
 | 
			
		||||
            raise ValueError("points and lengths must have same batch dimension.")
 | 
			
		||||
        if lengths.max() > P:
 | 
			
		||||
            raise ValueError("A value in lengths was too large.")
 | 
			
		||||
 | 
			
		||||
    # TODO: support providing K as a ratio of the total number of points instead of as an int
 | 
			
		||||
    if isinstance(K, int):
 | 
			
		||||
@ -107,9 +109,11 @@ def sample_farthest_points_naive(
 | 
			
		||||
    # Validate inputs
 | 
			
		||||
    if lengths is None:
 | 
			
		||||
        lengths = torch.full((N,), P, dtype=torch.int64, device=device)
 | 
			
		||||
 | 
			
		||||
    if lengths.shape[0] != N:
 | 
			
		||||
        raise ValueError("points and lengths must have same batch dimension.")
 | 
			
		||||
    else:
 | 
			
		||||
        if lengths.shape != (N,):
 | 
			
		||||
            raise ValueError("points and lengths must have same batch dimension.")
 | 
			
		||||
        if lengths.max() > P:
 | 
			
		||||
            raise ValueError("Invalid lengths.")
 | 
			
		||||
 | 
			
		||||
    # TODO: support providing K as a ratio of the total number of points instead of as an int
 | 
			
		||||
    if isinstance(K, int):
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user