diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py index dc2158fa..49690ec3 100644 --- a/pytorch3d/loss/chamfer.py +++ b/pytorch3d/loss/chamfer.py @@ -48,10 +48,11 @@ def _handle_pointcloud_input( if points.ndim != 3: raise ValueError("Expected points to be of shape (N, P, D)") X = points - if lengths is not None and ( - lengths.ndim != 1 or lengths.shape[0] != X.shape[0] - ): - raise ValueError("Expected lengths to be of shape (N,)") + if lengths is not None: + if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]: + raise ValueError("Expected lengths to be of shape (N,)") + if lengths.max() > X.shape[1]: + raise ValueError("A length value was too long") if lengths is None: lengths = torch.full( (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device diff --git a/pytorch3d/ops/sample_farthest_points.py b/pytorch3d/ops/sample_farthest_points.py index 92f3efcd..a2ff2e35 100644 --- a/pytorch3d/ops/sample_farthest_points.py +++ b/pytorch3d/ops/sample_farthest_points.py @@ -56,9 +56,11 @@ def sample_farthest_points( # Validate inputs if lengths is None: lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - if lengths.shape != (N,): - raise ValueError("points and lengths must have same batch dimension.") + else: + if lengths.shape != (N,): + raise ValueError("points and lengths must have same batch dimension.") + if lengths.max() > P: + raise ValueError("A value in lengths was too large.") # TODO: support providing K as a ratio of the total number of points instead of as an int if isinstance(K, int): @@ -107,9 +109,11 @@ def sample_farthest_points_naive( # Validate inputs if lengths is None: lengths = torch.full((N,), P, dtype=torch.int64, device=device) - - if lengths.shape[0] != N: - raise ValueError("points and lengths must have same batch dimension.") + else: + if lengths.shape != (N,): + raise ValueError("points and lengths must have same batch dimension.") + if lengths.max() > P: + raise ValueError("Invalid lengths.") # TODO: support providing K as a ratio of the total number of points instead of as an int if isinstance(K, int):