mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-21 14:50:36 +08:00
add L1 support for KNN & Chamfer
Summary: Added L1 norm for KNN and chamfer op * The norm is now specified with a variable `norm` which can only be 1 or 2 Reviewed By: bottler Differential Revision: D35419637 fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
committed by
Facebook GitHub Bot
parent
4b94649f7b
commit
67fff956a2
@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t D,
|
||||
const size_t K) {
|
||||
const size_t K,
|
||||
const size_t norm) {
|
||||
// Store both dists and indices for knn in global memory.
|
||||
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
|
||||
const int64_t chunks_to_do = N * chunks_per_cloud;
|
||||
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
|
||||
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
|
||||
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
|
||||
scalar_t diff = coord1 - coord2;
|
||||
dist += diff * diff;
|
||||
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
|
||||
dist += norm_diff;
|
||||
}
|
||||
mink.add(dist, p2);
|
||||
}
|
||||
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t K) {
|
||||
const size_t K,
|
||||
const size_t norm) {
|
||||
// Same idea as the previous version, but hoist D into a template argument
|
||||
// so we can cache the current point in a thread-local array. We still store
|
||||
// the current best K dists and indices in global memory, so this should work
|
||||
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
|
||||
scalar_t dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
|
||||
dist += diff * diff;
|
||||
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
|
||||
dist += norm_diff;
|
||||
}
|
||||
mink.add(dist, p2);
|
||||
}
|
||||
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t K) {
|
||||
const size_t K,
|
||||
const size_t norm) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
|
||||
int64_t* __restrict__ idxs,
|
||||
const int64_t N,
|
||||
const int64_t P1,
|
||||
const int64_t P2) {
|
||||
const int64_t P2,
|
||||
const size_t norm) {
|
||||
// Same general implementation as V2, but also hoist K into a template arg.
|
||||
scalar_t cur_point[D];
|
||||
scalar_t min_dists[K];
|
||||
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
|
||||
for (int d = 0; d < D; ++d) {
|
||||
int offset = n * P2 * D + p2 * D + d;
|
||||
scalar_t diff = cur_point[d] - points2[offset];
|
||||
dist += diff * diff;
|
||||
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
|
||||
dist += norm_diff;
|
||||
}
|
||||
mink.add(dist, p2);
|
||||
}
|
||||
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
|
||||
int64_t* __restrict__ idxs,
|
||||
const int64_t N,
|
||||
const int64_t P1,
|
||||
const int64_t P2) {
|
||||
const int64_t P2,
|
||||
const size_t norm) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2) {
|
||||
const size_t P2,
|
||||
const size_t norm) {
|
||||
// Same idea as V2, but use register indexing for thread-local arrays.
|
||||
// Enabling sorting for this version leads to huge slowdowns; I suspect
|
||||
// that it forces min_dists into local memory rather than registers.
|
||||
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
|
||||
for (int d = 0; d < D; ++d) {
|
||||
int offset = n * P2 * D + p2 * D + d;
|
||||
scalar_t diff = cur_point[d] - points2[offset];
|
||||
dist += diff * diff;
|
||||
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
|
||||
dist += norm_diff;
|
||||
}
|
||||
mink.add(dist, p2);
|
||||
}
|
||||
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
|
||||
int64_t* __restrict__ idxs,
|
||||
const size_t N,
|
||||
const size_t P1,
|
||||
const size_t P2) {
|
||||
const size_t P2,
|
||||
const size_t norm) {
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
|
||||
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
const int norm,
|
||||
const int K,
|
||||
int version) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
||||
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const auto D = p2.size(2);
|
||||
const int64_t K_64 = K;
|
||||
|
||||
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
|
||||
|
||||
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
|
||||
auto long_dtype = lengths1.options().dtype(at::kLong);
|
||||
auto idxs = at::zeros({N, P1, K}, long_dtype);
|
||||
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
P1,
|
||||
P2,
|
||||
D,
|
||||
K);
|
||||
K,
|
||||
norm);
|
||||
}));
|
||||
} else if (version == 1) {
|
||||
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
N,
|
||||
P1,
|
||||
P2,
|
||||
K);
|
||||
K,
|
||||
norm);
|
||||
}));
|
||||
} else if (version == 2) {
|
||||
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
P1,
|
||||
P2);
|
||||
P2,
|
||||
norm);
|
||||
}));
|
||||
} else if (version == 3) {
|
||||
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
|
||||
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
idxs.data_ptr<int64_t>(),
|
||||
N,
|
||||
P1,
|
||||
P2);
|
||||
P2,
|
||||
norm);
|
||||
}));
|
||||
}
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
|
||||
const size_t P1,
|
||||
const size_t P2,
|
||||
const size_t K,
|
||||
const size_t D) {
|
||||
const size_t D,
|
||||
const size_t norm) {
|
||||
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const size_t stride = gridDim.x * blockDim.x;
|
||||
|
||||
@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
|
||||
if (p2_idx == -1) {
|
||||
continue;
|
||||
}
|
||||
const float diff = 2.0 * grad_dist *
|
||||
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
||||
float diff = 0.0;
|
||||
if (norm == 1) {
|
||||
float sign =
|
||||
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
|
||||
? 1.0
|
||||
: -1.0;
|
||||
diff = grad_dist * sign;
|
||||
} else { // norm is 2
|
||||
diff = 2.0 * grad_dist *
|
||||
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
|
||||
}
|
||||
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
|
||||
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
|
||||
}
|
||||
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
int norm,
|
||||
const at::Tensor& grad_dists) {
|
||||
// Check inputs are on the same device
|
||||
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
|
||||
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
P1,
|
||||
P2,
|
||||
K,
|
||||
D);
|
||||
D,
|
||||
norm);
|
||||
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
return std::make_tuple(grad_p1, grad_p2);
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
// containing P2 points of dimension D.
|
||||
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
|
||||
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
|
||||
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
|
||||
// K: int giving the number of nearest points to return.
|
||||
// version: Integer telling which implementation to use.
|
||||
//
|
||||
@@ -41,7 +42,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K);
|
||||
const int norm,
|
||||
const int K);
|
||||
|
||||
// CUDA implementation
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
@@ -49,8 +51,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
int version);
|
||||
const int norm,
|
||||
const int K,
|
||||
const int version);
|
||||
|
||||
// Implementation which is exposed.
|
||||
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
@@ -58,18 +61,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K,
|
||||
int version) {
|
||||
const int norm,
|
||||
const int K,
|
||||
const int version) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(p1);
|
||||
CHECK_CUDA(p2);
|
||||
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
|
||||
return KNearestNeighborIdxCuda(
|
||||
p1, p2, lengths1, lengths2, norm, K, version);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
|
||||
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
|
||||
}
|
||||
|
||||
// Compute gradients with respect to p1 and p2
|
||||
@@ -86,6 +91,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
|
||||
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
|
||||
// It is padded with zeros so that it can be used easily in a later
|
||||
// gather() operation. This is computed from the forward pass.
|
||||
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
|
||||
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
|
||||
// gradients.
|
||||
//
|
||||
@@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const int norm,
|
||||
const at::Tensor& grad_dists);
|
||||
|
||||
// CUDA implementation
|
||||
@@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const int norm,
|
||||
const at::Tensor& grad_dists);
|
||||
|
||||
// Implementation which is exposed.
|
||||
@@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const int norm,
|
||||
const at::Tensor& grad_dists) {
|
||||
if (p1.is_cuda() || p2.is_cuda()) {
|
||||
#ifdef WITH_CUDA
|
||||
CHECK_CUDA(p1);
|
||||
CHECK_CUDA(p2);
|
||||
return KNearestNeighborBackwardCuda(
|
||||
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
||||
#else
|
||||
AT_ERROR("Not compiled with GPU support.");
|
||||
#endif
|
||||
}
|
||||
return KNearestNeighborBackwardCpu(
|
||||
p1, p2, lengths1, lengths2, idxs, grad_dists);
|
||||
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
|
||||
}
|
||||
|
||||
// Utility to check whether a KNN version can be used.
|
||||
|
||||
@@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
const at::Tensor& p2,
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
int K) {
|
||||
const int norm,
|
||||
const int K) {
|
||||
const int N = p1.size(0);
|
||||
const int P1 = p1.size(1);
|
||||
const int D = p1.size(2);
|
||||
@@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
|
||||
float dist = 0;
|
||||
for (int d = 0; d < D; ++d) {
|
||||
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
|
||||
dist += diff * diff;
|
||||
if (norm == 1) {
|
||||
dist += abs(diff);
|
||||
} else { // norm is 2 (default)
|
||||
dist += diff * diff;
|
||||
}
|
||||
}
|
||||
int size = static_cast<int>(q.size());
|
||||
if (size < K || dist < std::get<0>(q.top())) {
|
||||
@@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||
const at::Tensor& lengths1,
|
||||
const at::Tensor& lengths2,
|
||||
const at::Tensor& idxs,
|
||||
const int norm,
|
||||
const at::Tensor& grad_dists) {
|
||||
const int N = p1.size(0);
|
||||
const int P1 = p1.size(1);
|
||||
@@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
|
||||
continue;
|
||||
}
|
||||
for (int64_t d = 0; d < D; ++d) {
|
||||
const float diff =
|
||||
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
|
||||
float diff = 0.0;
|
||||
if (norm == 1) {
|
||||
float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
|
||||
diff = grad_dists_a[n][i1][k] * sign;
|
||||
} else { // norm is 2 (default)
|
||||
diff = 2.0f * grad_dists_a[n][i1][k] *
|
||||
(p1_a[n][i1][d] - p2_a[n][i2][d]);
|
||||
}
|
||||
grad_p1_a[n][i1][d] += diff;
|
||||
grad_p2_a[n][i2][d] += -1.0f * diff;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user