mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-02-07 14:52:21 +08:00
add L1 support for KNN & Chamfer
Summary: Added L1 norm for KNN and chamfer op * The norm is now specified with a variable `norm` which can only be 1 or 2 Reviewed By: bottler Differential Revision: D35419637 fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4b94649f7b
commit
67fff956a2
@@ -43,8 +43,9 @@ class _ball_query(Function):
|
||||
p2 = p2.float()
|
||||
|
||||
# Reuse the KNN backward function
|
||||
# by default, norm is 2
|
||||
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||
p1, p2, lengths1, lengths2, idx, 2, grad_dists
|
||||
)
|
||||
return grad_p1, grad_p2, None, None, None, None
|
||||
|
||||
|
||||
@@ -24,7 +24,15 @@ class _knn_points(Function):
|
||||
@staticmethod
|
||||
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
|
||||
def forward(
|
||||
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
|
||||
ctx,
|
||||
p1,
|
||||
p2,
|
||||
lengths1,
|
||||
lengths2,
|
||||
K,
|
||||
version,
|
||||
norm: int = 2,
|
||||
return_sorted: bool = True,
|
||||
):
|
||||
"""
|
||||
K-Nearest neighbors on point clouds.
|
||||
@@ -43,6 +51,7 @@ class _knn_points(Function):
|
||||
K: Integer giving the number of nearest neighbors to return.
|
||||
version: Which KNN implementation to use in the backend. If version=-1,
|
||||
the correct implementation is selected based on the shapes of the inputs.
|
||||
norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
|
||||
return_sorted: (bool) whether to return the nearest neighbors sorted in
|
||||
ascending order of distance.
|
||||
|
||||
@@ -57,8 +66,10 @@ class _knn_points(Function):
|
||||
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
|
||||
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
|
||||
"""
|
||||
if not ((norm == 1) or (norm == 2)):
|
||||
raise ValueError("Support for 1 or 2 norm.")
|
||||
|
||||
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
|
||||
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
|
||||
|
||||
# sort KNN in ascending order if K > 1
|
||||
if K > 1 and return_sorted:
|
||||
@@ -78,12 +89,14 @@ class _knn_points(Function):
|
||||
|
||||
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
|
||||
ctx.mark_non_differentiable(idx)
|
||||
ctx.norm = norm
|
||||
return dists, idx
|
||||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
def backward(ctx, grad_dists, grad_idx):
|
||||
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
|
||||
norm = ctx.norm
|
||||
# TODO(gkioxari) Change cast to floats once we add support for doubles.
|
||||
if not (grad_dists.dtype == torch.float32):
|
||||
grad_dists = grad_dists.float()
|
||||
@@ -92,9 +105,9 @@ class _knn_points(Function):
|
||||
if not (p2.dtype == torch.float32):
|
||||
p2 = p2.float()
|
||||
grad_p1, grad_p2 = _C.knn_points_backward(
|
||||
p1, p2, lengths1, lengths2, idx, grad_dists
|
||||
p1, p2, lengths1, lengths2, idx, norm, grad_dists
|
||||
)
|
||||
return grad_p1, grad_p2, None, None, None, None, None
|
||||
return grad_p1, grad_p2, None, None, None, None, None, None
|
||||
|
||||
|
||||
def knn_points(
|
||||
@@ -102,6 +115,7 @@ def knn_points(
|
||||
p2: torch.Tensor,
|
||||
lengths1: Union[torch.Tensor, None] = None,
|
||||
lengths2: Union[torch.Tensor, None] = None,
|
||||
norm: int = 2,
|
||||
K: int = 1,
|
||||
version: int = -1,
|
||||
return_nn: bool = False,
|
||||
@@ -121,6 +135,7 @@ def knn_points(
|
||||
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
|
||||
length of each pointcloud in p2. Or None to indicate that every cloud has
|
||||
length P2.
|
||||
norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
|
||||
K: Integer giving the number of nearest neighbors to return.
|
||||
version: Which KNN implementation to use in the backend. If version=-1,
|
||||
the correct implementation is selected based on the shapes of the inputs.
|
||||
@@ -172,7 +187,7 @@ def knn_points(
|
||||
|
||||
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
|
||||
p1_dists, p1_idx = _knn_points.apply(
|
||||
p1, p2, lengths1, lengths2, K, version, return_sorted
|
||||
p1, p2, lengths1, lengths2, K, version, norm, return_sorted
|
||||
)
|
||||
|
||||
p2_nn = None
|
||||
|
||||
Reference in New Issue
Block a user