add L1 support for KNN & Chamfer

Summary:
Added L1 norm support for the KNN and chamfer ops.
* The norm is now specified with an argument `norm`, which can only be 1 or 2.
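A minimal usage sketch of the new argument (toy random inputs; `norm` defaults to 2, so existing callers are unaffected):

    import torch
    from pytorch3d.loss import chamfer_distance
    from pytorch3d.ops import knn_points

    x = torch.rand(2, 128, 3)  # (N, P1, D) batch of point clouds
    y = torch.rand(2, 256, 3)  # (N, P2, D)

    # Chamfer loss built on L1 nearest-neighbor distances
    loss, _ = chamfer_distance(x, y, norm=1)

    # KNN with L1; as in the (squared) L2 case, dists are not square-rooted
    dists, idx, _ = knn_points(x, y, K=3, norm=1)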

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
Author: Georgia Gkioxari (committed by Facebook GitHub Bot)
Date: 2022-04-10 10:27:20 -07:00
Commit: 67fff956a2, parent: 4b94649f7b
8 changed files with 265 additions and 129 deletions

pytorch3d/csrc/knn/knn.cu

@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P1,
     const size_t P2,
     const size_t D,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Store both dists and indices for knn in global memory.
   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
         scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
         scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
         scalar_t diff = coord1 - coord2;
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
     const size_t N,
     const size_t P1,
     const size_t P2,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Same idea as the previous version, but hoist D into a template argument
   // so we can cache the current point in a thread-local array. We still store
   // the current best K dists and indices in global memory, so this should work
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
       const size_t N,
       const size_t P1,
       const size_t P2,
-      const size_t K) {
+      const size_t K,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
   }
 };
@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
     int64_t* __restrict__ idxs,
     const int64_t N,
     const int64_t P1,
-    const int64_t P2) {
+    const int64_t P2,
+    const size_t norm) {
   // Same general implementation as V1, but also hoist K into a template arg.
   scalar_t cur_point[D];
   scalar_t min_dists[K];
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
       int64_t* __restrict__ idxs,
       const int64_t N,
       const int64_t P1,
-      const int64_t P2) {
+      const int64_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
     int64_t* __restrict__ idxs,
     const size_t N,
     const size_t P1,
-    const size_t P2) {
+    const size_t P2,
+    const size_t norm) {
   // Same idea as V2, but use register indexing for thread-local arrays.
   // Enabling sorting for this version leads to huge slowdowns; I suspect
   // that it forces min_dists into local memory rather than registers.
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
-      const size_t P2) {
+      const size_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
+    const int norm,
+    const int K,
     int version) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const auto D = p2.size(2);
   const int64_t K_64 = K;

+  TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
   TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
   auto long_dtype = lengths1.options().dtype(at::kLong);
   auto idxs = at::zeros({N, P1, K}, long_dtype);
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                 P1,
                 P2,
                 D,
-                K);
+                K,
+                norm);
           }));
   } else if (version == 1) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                 N,
                 P1,
                 P2,
-                K);
+                K,
+                norm);
           }));
   } else if (version == 2) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                 idxs.data_ptr<int64_t>(),
                 N,
                 P1,
-                P2);
+                P2,
+                norm);
           }));
   } else if (version == 3) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
                 idxs.data_ptr<int64_t>(),
                 N,
                 P1,
-                P2);
+                P2,
+                norm);
           }));
   }
   AT_CUDA_CHECK(cudaGetLastError());
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
     const size_t P1,
     const size_t P2,
     const size_t K,
-    const size_t D) {
+    const size_t D,
+    const size_t norm) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = gridDim.x * blockDim.x;
@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
       if (p2_idx == -1) {
         continue;
       }
-      const float diff = 2.0 * grad_dist *
-          (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      float diff = 0.0;
+      if (norm == 1) {
+        float sign =
+            (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+            ? 1.0
+            : -1.0;
+        diff = grad_dist * sign;
+      } else { // norm is 2
+        diff = 2.0 * grad_dist *
+            (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      }
       atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
       atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
     }
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    int norm,
     const at::Tensor& grad_dists) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
       P1,
       P2,
       K,
-      D);
+      D,
+      norm);
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_p1, grad_p2);
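The branch on norm in the backward kernel mirrors the derivative of the per-coordinate distance used in the forward kernels; as a sketch (the L1 case is a subgradient, with the kernel choosing -1 at exact ties):

    d_2(p_1, p_2) = \sum_d (p_{1,d} - p_{2,d})^2, \qquad
        \partial d_2 / \partial p_{1,d} = 2\,(p_{1,d} - p_{2,d})

    d_1(p_1, p_2) = \sum_d \lvert p_{1,d} - p_{2,d} \rvert, \qquad
        \partial d_1 / \partial p_{1,d} = \operatorname{sign}(p_{1,d} - p_{2,d})

Each coordinate's contribution is scaled by the incoming grad_dist, and the gradient with respect to p2 is the negation, hence the paired atomicAdd of diff into grad_p1 and -diff into grad_p2.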

pytorch3d/csrc/knn/knn.h

@@ -21,6 +21,7 @@
 //        containing P2 points of dimension D.
 // lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
 // lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 // K: int giving the number of nearest points to return.
 // version: Integer telling which implementation to use.
 //
@@ -41,7 +42,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K);
+    const int norm,
+    const int K);

 // CUDA implementation
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
@@ -49,8 +51,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
-    int version);
+    const int norm,
+    const int K,
+    const int version);

 // Implementation which is exposed.
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
@@ -58,18 +61,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
-    int version) {
+    const int norm,
+    const int K,
+    const int version) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(p1);
     CHECK_CUDA(p2);
-    return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
+    return KNearestNeighborIdxCuda(
+        p1, p2, lengths1, lengths2, norm, K, version);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
+  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
 }

 // Compute gradients with respect to p1 and p2
@@ -86,6 +91,7 @@
 //        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
 //        It is padded with zeros so that it can be used easily in a later
 //        gather() operation. This is computed from the forward pass.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 // grad_dists: FloatTensor of shape (N, P1, K) which contains the input
 //        gradients.
 //
@@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists);

 // CUDA implementation
@@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists);

 // Implementation which is exposed.
@@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(p1);
     CHECK_CUDA(p2);
     return KNearestNeighborBackwardCuda(
-        p1, p2, lengths1, lengths2, idxs, grad_dists);
+        p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
   return KNearestNeighborBackwardCpu(
-      p1, p2, lengths1, lengths2, idxs, grad_dists);
+      p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
 }

 // Utility to check whether a KNN version can be used.

pytorch3d/csrc/knn/knn_cpu.cpp

@@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K) {
+    const int norm,
+    const int K) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int D = p1.size(2);
@@ -41,8 +42,12 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
         float dist = 0;
         for (int d = 0; d < D; ++d) {
           float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
-          dist += diff * diff;
+          if (norm == 1) {
+            dist += abs(diff);
+          } else { // norm is 2 (default)
+            dist += diff * diff;
+          }
         }
         int size = static_cast<int>(q.size());
         if (size < K || dist < std::get<0>(q.top())) {
           q.emplace(dist, i2);
@@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
@@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
           continue;
         }
         for (int64_t d = 0; d < D; ++d) {
-          const float diff =
-              2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          float diff = 0.0;
+          if (norm == 1) {
+            float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
+            diff = grad_dists_a[n][i1][k] * sign;
+          } else { // norm is 2 (default)
+            diff = 2.0f * grad_dists_a[n][i1][k] *
+                (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          }
           grad_p1_a[n][i1][d] += diff;
           grad_p2_a[n][i2][d] += -1.0f * diff;
         }
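The new L1 backward path can be sanity-checked numerically with central differences; a hedged sketch (the op computes in float32, so the tolerance is loose, and random inputs avoid the non-differentiable ties at zero):

    import torch
    from pytorch3d.ops import knn_points

    torch.manual_seed(0)
    x = torch.rand(1, 6, 3, requires_grad=True)
    y = torch.rand(1, 9, 3)

    knn_points(x, y, K=2, norm=1).dists.sum().backward()

    # perturb a single coordinate and compare against the analytic gradient
    eps = 1e-4
    xp, xm = x.detach().clone(), x.detach().clone()
    xp[0, 0, 0] += eps
    xm[0, 0, 0] -= eps
    fp = knn_points(xp, y, K=2, norm=1).dists.sum()
    fm = knn_points(xm, y, K=2, norm=1).dists.sum()
    assert torch.allclose(x.grad[0, 0, 0], (fp - fm) / (2 * eps), atol=1e-2)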

pytorch3d/loss/chamfer.py

@@ -77,6 +77,7 @@ def chamfer_distance(
     weights=None,
     batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
+    norm: int = 2,
 ):
     """
     Chamfer distance between two pointclouds x and y.
@@ -100,6 +101,7 @@ def chamfer_distance(
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
             points, can be one of ["mean", "sum"].
+        norm: int indicating the norm used for the distance. Supports 1 for L1 and 2 for L2.

     Returns:
         2-element tuple containing
@@ -112,6 +114,9 @@ def chamfer_distance(
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

+    if not ((norm == 1) or (norm == 2)):
+        raise ValueError("Support for 1 or 2 norm.")
+
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
@@ -149,8 +154,8 @@ def chamfer_distance(
     cham_norm_x = x.new_zeros(())
     cham_norm_y = x.new_zeros(())

-    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
-    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
+    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)

     cham_x = x_nn.dists[..., 0]  # (N, P1)
     cham_y = y_nn.dists[..., 0]  # (N, P2)

pytorch3d/ops/ball_query.py

@@ -43,8 +43,9 @@ class _ball_query(Function):
             p2 = p2.float()

         # Reuse the KNN backward function
+        # by default, norm is 2
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, 2, grad_dists
         )
         return grad_p1, grad_p2, None, None, None, None

pytorch3d/ops/knn.py

@@ -24,7 +24,15 @@ class _knn_points(Function):
     @staticmethod
     # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
     def forward(
-        ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
+        ctx,
+        p1,
+        p2,
+        lengths1,
+        lengths2,
+        K,
+        version,
+        norm: int = 2,
+        return_sorted: bool = True,
     ):
         """
         K-Nearest neighbors on point clouds.
@@ -43,6 +51,7 @@ class _knn_points(Function):
             K: Integer giving the number of nearest neighbors to return.
             version: Which KNN implementation to use in the backend. If version=-1,
                 the correct implementation is selected based on the shapes of the inputs.
+            norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
             return_sorted: (bool) whether to return the nearest neighbors sorted in
                 ascending order of distance.
@@ -57,8 +66,10 @@ class _knn_points(Function):
             neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
             in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
         """
+        if not ((norm == 1) or (norm == 2)):
+            raise ValueError("Support for 1 or 2 norm.")

-        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
+        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)

         # sort KNN in ascending order if K > 1
         if K > 1 and return_sorted:
@@ -78,12 +89,14 @@ class _knn_points(Function):
         ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
         ctx.mark_non_differentiable(idx)
+        ctx.norm = norm
         return dists, idx

     @staticmethod
     @once_differentiable
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+        norm = ctx.norm
         # TODO(gkioxari) Change cast to floats once we add support for doubles.
         if not (grad_dists.dtype == torch.float32):
             grad_dists = grad_dists.float()
@@ -92,9 +105,9 @@ class _knn_points(Function):
         if not (p2.dtype == torch.float32):
             p2 = p2.float()
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, norm, grad_dists
         )
-        return grad_p1, grad_p2, None, None, None, None, None
+        return grad_p1, grad_p2, None, None, None, None, None, None

 def knn_points(
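The extra trailing `None` in `backward` above is required by the autograd contract: `backward` must return one value per input to `forward` (excluding `ctx`), and `forward` gained the new `norm` input. A minimal sketch of the pattern with a hypothetical toy Function:

    import torch
    from torch.autograd import Function

    class _ScaledSquare(Function):  # hypothetical example, not part of this diff
        @staticmethod
        def forward(ctx, x, scale: int = 2):
            ctx.save_for_backward(x)
            ctx.scale = scale  # non-tensor hyperparameters are stashed on ctx
            return scale * x * x

        @staticmethod
        def backward(ctx, grad_out):
            (x,) = ctx.saved_tensors
            # one return per forward input: a gradient for x, None for scale
            return grad_out * 2 * ctx.scale * x, None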
@@ -102,6 +115,7 @@ def knn_points(
     p2: torch.Tensor,
     lengths1: Union[torch.Tensor, None] = None,
     lengths2: Union[torch.Tensor, None] = None,
+    norm: int = 2,
     K: int = 1,
     version: int = -1,
     return_nn: bool = False,
@@ -121,6 +135,7 @@ def knn_points(
         lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
             length of each pointcloud in p2. Or None to indicate that every cloud has
             length P2.
+        norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
         K: Integer giving the number of nearest neighbors to return.
         version: Which KNN implementation to use in the backend. If version=-1,
             the correct implementation is selected based on the shapes of the inputs.
@@ -172,7 +187,7 @@ def knn_points(
     # pyre-fixme[16]: `_knn_points` has no attribute `apply`.
     p1_dists, p1_idx = _knn_points.apply(
-        p1, p2, lengths1, lengths2, K, version, return_sorted
+        p1, p2, lengths1, lengths2, K, version, norm, return_sorted
     )

     p2_nn = None
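Note that the convention for the returned distances carries over from the L2 case: they are never square-rooted, so norm=2 yields squared Euclidean distances and norm=1 plain sums of absolute differences. A hedged self-check against torch.cdist (which does take the root for p=2, hence the squaring):

    import torch
    from pytorch3d.ops import knn_points

    x, y = torch.rand(1, 10, 3), torch.rand(1, 20, 3)
    d1 = knn_points(x, y, K=1, norm=1).dists[..., 0]
    d2 = knn_points(x, y, K=1, norm=2).dists[..., 0]
    assert torch.allclose(d1, torch.cdist(x, y, p=1).min(dim=2).values, atol=1e-5)
    assert torch.allclose(d2, torch.cdist(x, y, p=2).min(dim=2).values ** 2, atol=1e-5)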

tests/test_chamfer.py

@@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
     )

     @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
+    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(x_lengths[n]):
                 for i2 in range(y_lengths[n]):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         x_dist = torch.min(dist, dim=2)[0]  # (N, P1)
         y_dist = torch.min(dist, dim=1)[0]  # (N, P2)
@@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         return loss, lnorm

     @staticmethod
-    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
+    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         Returns lists of the unreduced loss and loss_normals. This naive
@@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(P1):
                 for i2 in range(P2):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         loss = [
             torch.min(dist, dim=2)[0],  # (N, P1)
@@ -208,6 +222,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 7, 10, 18
         device = get_random_cuda_device()
+        for norm in [1, 2]:
             points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
             p1 = points_normals.p1
             p2 = points_normals.p2
@@ -219,10 +235,12 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             P1 = p1.shape[1]
             P2 = p2.shape[1]

-            pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
+            pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+                p1, p2, norm=norm
+            )

             # point_reduction = "mean".
-            loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
+            loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
             pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
             pred_loss *= weights
             pred_loss = pred_loss.sum() / weights.sum()
@@ -242,6 +260,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 3, 70, 70
         device = get_random_cuda_device()
+        for norm in [1, 2]:
             points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
             weights = points_normals.weights
             x_lengths = points_normals.p1_lengths
@@ -256,11 +276,12 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
                 x_lengths=points_normals.p1_lengths,
                 y_lengths=points_normals.p2_lengths,
                 weights=weights,
+                norm=norm,
             )

             # Chamfer with pointclouds as input.
             pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-                points_normals.cloud1, points_normals.cloud2, device=device
+                points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
             )
             # Mean reduction point loss.
@@ -276,7 +297,8 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
             pred_norm_loss[0] *= weights.view(N, 1)
             pred_norm_loss[1] *= weights.view(N, 1)
             pred_norm_loss_mean = (
-                pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
+                pred_norm_loss[0].sum(1) / x_lengths
+                + pred_norm_loss[1].sum(1) / y_lengths
             )
             pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
@@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
             chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])

+    def test_invalid_norm(self):
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=0)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=3)
+
     @staticmethod
     def chamfer_with_init(
         batch_size: int,

tests/test_knn.py

@@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         torch.manual_seed(1)

     @staticmethod
-    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
+    def _knn_points_naive(
+        p1, p2, lengths1, lengths2, K: int, norm: int = 2
+    ) -> torch.Tensor:
         """
         Naive PyTorch implementation of K-Nearest Neighbors.
         Returns always sorted results
@@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             pp1 = p1[n, :num1].view(num1, 1, D)
             pp2 = p2[n, :num2].view(1, num2, D)
             diff = pp1 - pp2
-            diff = (diff * diff).sum(2)
+            if norm == 2:
+                diff = (diff * diff).sum(2)
+            elif norm == 1:
+                diff = diff.abs().sum(2)
+            else:
+                raise ValueError("No support for norm %d" % (norm))
             num2 = min(num2, K)
             for i in range(num1):
                 dd = diff[i]
@@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
+        norms = [1, 2]
         versions = [0, 1, 2, 3]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             for version in versions:
                 if version == 3 and K > 4:
                     continue
@@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
                 y_cuda.requires_grad_(True)

                 # forward
-                out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
+                out1 = self._knn_points_naive(
+                    x, y, lengths1=None, lengths2=None, K=K, norm=norm
+                )
                 out2 = knn_points(
-                    x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
+                    x_cuda,
+                    y_cuda,
+                    K=K,
+                    norm=norm,
+                    version=version,
+                    return_sorted=return_sorted,
                 )
                 if K > 1 and not return_sorted:
                     # check out2 is not sorted
@@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        norms = [1, 2]
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             x = torch.rand((N, P1, D), device=device, requires_grad=True)
             y = torch.rand((N, P2, D), device=device, requires_grad=True)
             lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             # forward
             out1 = self._knn_points_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
             )
-            out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
+            out2 = knn_points(
+                x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
+            )

             self.assertClose(out1[0], out2[0])
             self.assertTrue(torch.all(out1[1] == out2[1]))
@@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             expected = all_expected[version]
             self.assertEqual(actual, expected)

+    def test_invalid_norm(self):
+        device = get_random_cuda_device()
+        N, P1, P2, K, D = 4, 16, 12, 8, 3
+        x = torch.rand((N, P1, D), device=device)
+        y = torch.rand((N, P2, D), device=device)
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=3)
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=0)
+
     @staticmethod
     def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
         device = torch.device(device)