validate lengths in chamfer and farthest_points

Summary: Fixes #1326

Reviewed By: kjchalup

Differential Revision: D39259697

fbshipit-source-id: 51392f4cc4a956165a62901cb115fcefe0e17277
This commit is contained in:
Jeremy Reizenstein
2022-09-08 15:03:36 -07:00
committed by Facebook GitHub Bot
parent 6e25fe8cb3
commit cb7bd33e7f
2 changed files with 15 additions and 10 deletions

View File

@@ -56,9 +56,11 @@ def sample_farthest_points(
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
else:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
if lengths.max() > P:
raise ValueError("A value in lengths was too large.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):
@@ -107,9 +109,11 @@ def sample_farthest_points_naive(
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
if lengths.shape[0] != N:
raise ValueError("points and lengths must have same batch dimension.")
else:
if lengths.shape != (N,):
raise ValueError("points and lengths must have same batch dimension.")
if lengths.max() > P:
raise ValueError("Invalid lengths.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
if isinstance(K, int):