mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
validate lengths in chamfer and farthest_points
Summary: Fixes #1326 Reviewed By: kjchalup Differential Revision: D39259697 fbshipit-source-id: 51392f4cc4a956165a62901cb115fcefe0e17277
This commit is contained in:
parent
6e25fe8cb3
commit
cb7bd33e7f
@ -48,10 +48,11 @@ def _handle_pointcloud_input(
|
|||||||
if points.ndim != 3:
|
if points.ndim != 3:
|
||||||
raise ValueError("Expected points to be of shape (N, P, D)")
|
raise ValueError("Expected points to be of shape (N, P, D)")
|
||||||
X = points
|
X = points
|
||||||
if lengths is not None and (
|
if lengths is not None:
|
||||||
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
|
if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]:
|
||||||
):
|
raise ValueError("Expected lengths to be of shape (N,)")
|
||||||
raise ValueError("Expected lengths to be of shape (N,)")
|
if lengths.max() > X.shape[1]:
|
||||||
|
raise ValueError("A length value was too long")
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
lengths = torch.full(
|
lengths = torch.full(
|
||||||
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
|
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
|
||||||
|
@ -56,9 +56,11 @@ def sample_farthest_points(
|
|||||||
# Validate inputs
|
# Validate inputs
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
else:
|
||||||
if lengths.shape != (N,):
|
if lengths.shape != (N,):
|
||||||
raise ValueError("points and lengths must have same batch dimension.")
|
raise ValueError("points and lengths must have same batch dimension.")
|
||||||
|
if lengths.max() > P:
|
||||||
|
raise ValueError("A value in lengths was too large.")
|
||||||
|
|
||||||
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||||
if isinstance(K, int):
|
if isinstance(K, int):
|
||||||
@ -107,9 +109,11 @@ def sample_farthest_points_naive(
|
|||||||
# Validate inputs
|
# Validate inputs
|
||||||
if lengths is None:
|
if lengths is None:
|
||||||
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
|
||||||
|
else:
|
||||||
if lengths.shape[0] != N:
|
if lengths.shape != (N,):
|
||||||
raise ValueError("points and lengths must have same batch dimension.")
|
raise ValueError("points and lengths must have same batch dimension.")
|
||||||
|
if lengths.max() > P:
|
||||||
|
raise ValueError("Invalid lengths.")
|
||||||
|
|
||||||
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
# TODO: support providing K as a ratio of the total number of points instead of as an int
|
||||||
if isinstance(K, int):
|
if isinstance(K, int):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user