add L1 support for KNN & Chamfer

Summary:
Added L1 norm for KNN and chamfer op
* The norm is now specified with a variable `norm` which can only be 1 or 2

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
Georgia Gkioxari
2022-04-10 10:27:20 -07:00
committed by Facebook GitHub Bot
parent 4b94649f7b
commit 67fff956a2
8 changed files with 265 additions and 129 deletions

View File

@@ -43,8 +43,9 @@ class _ball_query(Function):
p2 = p2.float()
# Reuse the KNN backward function
# by default, norm is 2
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
p1, p2, lengths1, lengths2, idx, 2, grad_dists
)
return grad_p1, grad_p2, None, None, None, None

View File

@@ -24,7 +24,15 @@ class _knn_points(Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
ctx,
p1,
p2,
lengths1,
lengths2,
K,
version,
norm: int = 2,
return_sorted: bool = True,
):
"""
K-Nearest neighbors on point clouds.
@@ -43,6 +51,7 @@ class _knn_points(Function):
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
@@ -57,8 +66,10 @@ class _knn_points(Function):
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
# sort KNN in ascending order if K > 1
if K > 1 and return_sorted:
@@ -78,12 +89,14 @@ class _knn_points(Function):
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
ctx.mark_non_differentiable(idx)
ctx.norm = norm
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
norm = ctx.norm
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
@@ -92,9 +105,9 @@ class _knn_points(Function):
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
p1, p2, lengths1, lengths2, idx, norm, grad_dists
)
return grad_p1, grad_p2, None, None, None, None, None
return grad_p1, grad_p2, None, None, None, None, None, None
def knn_points(
@@ -102,6 +115,7 @@ def knn_points(
p2: torch.Tensor,
lengths1: Union[torch.Tensor, None] = None,
lengths2: Union[torch.Tensor, None] = None,
norm: int = 2,
K: int = 1,
version: int = -1,
return_nn: bool = False,
@@ -121,6 +135,7 @@ def knn_points(
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
@@ -172,7 +187,7 @@ def knn_points(
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
p1_dists, p1_idx = _knn_points.apply(
p1, p2, lengths1, lengths2, K, version, return_sorted
p1, p2, lengths1, lengths2, K, version, norm, return_sorted
)
p2_nn = None