add L1 support for KNN & Chamfer
Summary: Added L1 norm support for the KNN and chamfer ops.

* The norm is now specified with a variable `norm`, which can only be 1 or 2.

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
Parent: 4b94649f7b
Commit: 67fff956a2
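
For orientation, a minimal usage sketch of the new argument (illustration only, not part of the commit; it assumes a PyTorch3D build that includes this change):

    import torch
    from pytorch3d.loss import chamfer_distance
    from pytorch3d.ops import knn_points

    x = torch.rand(2, 128, 3)  # (N, P1, D) batch of point clouds
    y = torch.rand(2, 256, 3)  # (N, P2, D)

    # Nearest neighbors under the L1 norm; norm defaults to 2 (L2).
    nn = knn_points(x, y, K=1, norm=1)

    # Chamfer distance under the L1 norm.
    loss, _ = chamfer_distance(x, y, norm=1)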
pytorch3d/csrc/knn/knn.cu

@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P1,
     const size_t P2,
     const size_t D,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Store both dists and indices for knn in global memory.
   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
         scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
         scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
         scalar_t diff = coord1 - coord2;
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
    const size_t N,
    const size_t P1,
    const size_t P2,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
  // Same idea as the previous version, but hoist D into a template argument
  // so we can cache the current point in a thread-local array. We still store
  // the current best K dists and indices in global memory, so this should work
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
      scalar_t dist = 0;
      for (int d = 0; d < D; ++d) {
        scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
      const size_t N,
      const size_t P1,
      const size_t P2,
-      const size_t K) {
+      const size_t K,
+      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
  }
};

@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
    int64_t* __restrict__ idxs,
    const int64_t N,
    const int64_t P1,
-    const int64_t P2) {
+    const int64_t P2,
+    const size_t norm) {
  // Same general implementation as V2, but also hoist K into a template arg.
  scalar_t cur_point[D];
  scalar_t min_dists[K];
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
      int64_t* __restrict__ idxs,
      const int64_t N,
      const int64_t P1,
-      const int64_t P2) {
+      const int64_t P2,
+      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
  }
};

@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
    int64_t* __restrict__ idxs,
    const size_t N,
    const size_t P1,
-    const size_t P2) {
+    const size_t P2,
+    const size_t norm) {
  // Same idea as V2, but use register indexing for thread-local arrays.
  // Enabling sorting for this version leads to huge slowdowns; I suspect
  // that it forces min_dists into local memory rather than registers.
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
      for (int d = 0; d < D; ++d) {
        int offset = n * P2 * D + p2 * D + d;
        scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
      }
      mink.add(dist, p2);
    }
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
      int64_t* __restrict__ idxs,
      const size_t N,
      const size_t P1,
-      const size_t P2) {
+      const size_t P2,
+      const size_t norm) {
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
  }
};

@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
-    int K,
+    const int norm,
+    const int K,
    int version) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
  const auto D = p2.size(2);
  const int64_t K_64 = K;

+  TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
  TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
  auto long_dtype = lengths1.options().dtype(at::kLong);
  auto idxs = at::zeros({N, P1, K}, long_dtype);
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
            P1,
            P2,
            D,
-            K);
+            K,
+            norm);
      }));
  } else if (version == 1) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
            N,
            P1,
            P2,
-            K);
+            K,
+            norm);
      }));
  } else if (version == 2) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
            idxs.data_ptr<int64_t>(),
            N,
            P1,
-            P2);
+            P2,
+            norm);
      }));
  } else if (version == 3) {
    AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
            idxs.data_ptr<int64_t>(),
            N,
            P1,
-            P2);
+            P2,
+            norm);
      }));
  }
  AT_CUDA_CHECK(cudaGetLastError());
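
All four forward kernel variants (V0-V3) get the same per-coordinate change: the squared difference is accumulated when norm == 2 and the absolute difference when norm == 1. In Python-style pseudocode (a sketch for illustration, not part of the commit):

    def pair_distance(a, b, norm):
        # Mirrors the kernels' inner loop. For norm == 2 the result is the
        # *squared* Euclidean distance (no sqrt is ever taken); for norm == 1
        # it is the sum of absolute coordinate differences.
        dist = 0.0
        for ai, bi in zip(a, b):
            diff = ai - bi
            dist += diff * diff if norm == 2 else abs(diff)
        return dist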
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
    const size_t P1,
    const size_t P2,
    const size_t K,
-    const size_t D) {
+    const size_t D,
+    const size_t norm) {
  const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const size_t stride = gridDim.x * blockDim.x;

@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
      if (p2_idx == -1) {
        continue;
      }
-      const float diff = 2.0 * grad_dist *
-          (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      float diff = 0.0;
+      if (norm == 1) {
+        float sign =
+            (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+            ? 1.0
+            : -1.0;
+        diff = grad_dist * sign;
+      } else { // norm is 2
+        diff = 2.0 * grad_dist *
+            (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      }
      atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
      atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
    }
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
+    int norm,
    const at::Tensor& grad_dists) {
  // Check inputs are on the same device
  at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
      P1,
      P2,
      K,
-      D);
+      D,
+      norm);

  AT_CUDA_CHECK(cudaGetLastError());
  return std::make_tuple(grad_p1, grad_p2);
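
The backward kernel differentiates the accumulated distance coordinate by coordinate: for norm == 2 the derivative with respect to p1 is 2 * (p1 - p2); for norm == 1 it is the sign of (p1 - p2), with the subgradient at zero taken as -1 (the strict > comparison). The same rule as a sketch (illustration only, not part of the commit):

    def coord_grad(p1_d, p2_d, grad_dist, norm):
        # Chain rule through one coordinate's contribution to the distance.
        if norm == 1:
            # d|p1 - p2| / dp1 = sign(p1 - p2); -1 is used when p1 == p2.
            return grad_dist * (1.0 if p1_d > p2_d else -1.0)
        # d(p1 - p2)^2 / dp1 = 2 * (p1 - p2)
        return 2.0 * grad_dist * (p1_d - p2_d)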
pytorch3d/csrc/knn/knn.h

@@ -21,6 +21,7 @@
 //        containing P2 points of dimension D.
 //    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
 //    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 //    K: int giving the number of nearest points to return.
 //    version: Integer telling which implementation to use.
 //
@@ -41,7 +42,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
-    int K);
+    const int norm,
+    const int K);

// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
@@ -49,8 +51,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
-    int K,
-    int version);
+    const int norm,
+    const int K,
+    const int version);

// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
@@ -58,18 +61,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
-    int K,
-    int version) {
+    const int norm,
+    const int K,
+    const int version) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
-    return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
+    return KNearestNeighborIdxCuda(
+        p1, p2, lengths1, lengths2, norm, K, version);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
-  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
+  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}

// Compute gradients with respect to p1 and p2
@@ -86,6 +91,7 @@
 //        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
 //        It is padded with zeros so that it can be used easily in a later
 //        gather() operation. This is computed from the forward pass.
+//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 //    grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
 //        gradients.
 //
@@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
+    const int norm,
    const at::Tensor& grad_dists);

// CUDA implementation
@@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
+    const int norm,
    const at::Tensor& grad_dists);

// Implementation which is exposed.
@@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
+    const int norm,
    const at::Tensor& grad_dists) {
  if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
    CHECK_CUDA(p1);
    CHECK_CUDA(p2);
    return KNearestNeighborBackwardCuda(
-        p1, p2, lengths1, lengths2, idxs, grad_dists);
+        p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
#else
    AT_ERROR("Not compiled with GPU support.");
#endif
  }
  return KNearestNeighborBackwardCpu(
-      p1, p2, lengths1, lengths2, idxs, grad_dists);
+      p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}

// Utility to check whether a KNN version can be used.
pytorch3d/csrc/knn/knn_cpu.cpp

@@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
-    int K) {
+    const int norm,
+    const int K) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
  const int D = p1.size(2);
@@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
        float dist = 0;
        for (int d = 0; d < D; ++d) {
          float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
-          dist += diff * diff;
+          if (norm == 1) {
+            dist += abs(diff);
+          } else { // norm is 2 (default)
+            dist += diff * diff;
+          }
        }
        int size = static_cast<int>(q.size());
        if (size < K || dist < std::get<0>(q.top())) {
@@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    const at::Tensor& idxs,
+    const int norm,
    const at::Tensor& grad_dists) {
  const int N = p1.size(0);
  const int P1 = p1.size(1);
@@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
          continue;
        }
        for (int64_t d = 0; d < D; ++d) {
-          const float diff =
-              2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          float diff = 0.0;
+          if (norm == 1) {
+            float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
+            diff = grad_dists_a[n][i1][k] * sign;
+          } else { // norm is 2 (default)
+            diff = 2.0f * grad_dists_a[n][i1][k] *
+                (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          }
          grad_p1_a[n][i1][d] += diff;
          grad_p2_a[n][i2][d] += -1.0f * diff;
        }
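
The CPU path mirrors the CUDA kernels, which makes it easy to sanity-check the new norm against a brute-force reference from Python. A sketch (illustration only, not part of the commit; ties in random data could make the index comparison flaky):

    import torch
    from pytorch3d.ops import knn_points

    x = torch.rand(1, 32, 3)
    y = torch.rand(1, 48, 3)

    # For norm=1 the returned dists are sums of absolute differences, which
    # match torch.cdist(..., p=1); for norm=2 they are *squared* Euclidean
    # distances, i.e. torch.cdist(..., p=2) ** 2.
    ref = torch.cdist(x, y, p=1).topk(4, dim=2, largest=False)
    out = knn_points(x, y, K=4, norm=1)

    assert torch.allclose(out.dists, ref.values, atol=1e-5)
    assert torch.equal(out.idx, ref.indices)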
pytorch3d/loss/chamfer.py

@@ -77,6 +77,7 @@ def chamfer_distance(
     weights=None,
     batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
+    norm: int = 2,
 ):
     """
     Chamfer distance between two pointclouds x and y.
@@ -100,6 +101,7 @@ def chamfer_distance(
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
             points, can be one of ["mean", "sum"].
+        norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.

     Returns:
         2-element tuple containing
@@ -112,6 +114,9 @@ def chamfer_distance(
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)

+    if not ((norm == 1) or (norm == 2)):
+        raise ValueError("Support for 1 or 2 norm.")
+
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

@@ -149,8 +154,8 @@ def chamfer_distance(
     cham_norm_x = x.new_zeros(())
     cham_norm_y = x.new_zeros(())

-    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
-    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
+    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)

     cham_x = x_nn.dists[..., 0]  # (N, P1)
     cham_y = y_nn.dists[..., 0]  # (N, P2)
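
For reference, the per-point term that chamfer_distance reduces (with K=1) is $\mathrm{cham}_x[n,i] = \min_j D(x_{n,i}, y_{n,j})$, where $D(a,b) = \sum_d |a_d - b_d|$ for norm=1 and $D(a,b) = \sum_d (a_d - b_d)^2$ for norm=2; note that under the L2 setting the distance is squared, not rooted.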
pytorch3d/ops/ball_query.py

@@ -43,8 +43,9 @@ class _ball_query(Function):
             p2 = p2.float()

         # Reuse the KNN backward function
+        # by default, norm is 2
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, 2, grad_dists
         )
         return grad_p1, grad_p2, None, None, None, None
pytorch3d/ops/knn.py

@@ -24,7 +24,15 @@ class _knn_points(Function):
     @staticmethod
     # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
     def forward(
-        ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
+        ctx,
+        p1,
+        p2,
+        lengths1,
+        lengths2,
+        K,
+        version,
+        norm: int = 2,
+        return_sorted: bool = True,
     ):
         """
         K-Nearest neighbors on point clouds.
@@ -43,6 +51,7 @@ class _knn_points(Function):
             K: Integer giving the number of nearest neighbors to return.
             version: Which KNN implementation to use in the backend. If version=-1,
                 the correct implementation is selected based on the shapes of the inputs.
+            norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
             return_sorted: (bool) whether to return the nearest neighbors sorted in
                 ascending order of distance.

@@ -57,8 +66,10 @@ class _knn_points(Function):
             neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
             in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
         """
+        if not ((norm == 1) or (norm == 2)):
+            raise ValueError("Support for 1 or 2 norm.")

-        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
+        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)

         # sort KNN in ascending order if K > 1
         if K > 1 and return_sorted:
@@ -78,12 +89,14 @@ class _knn_points(Function):

         ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
         ctx.mark_non_differentiable(idx)
+        ctx.norm = norm
         return dists, idx

     @staticmethod
     @once_differentiable
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+        norm = ctx.norm
         # TODO(gkioxari) Change cast to floats once we add support for doubles.
         if not (grad_dists.dtype == torch.float32):
             grad_dists = grad_dists.float()
@@ -92,9 +105,9 @@ class _knn_points(Function):
         if not (p2.dtype == torch.float32):
             p2 = p2.float()
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, norm, grad_dists
         )
-        return grad_p1, grad_p2, None, None, None, None, None
+        return grad_p1, grad_p2, None, None, None, None, None, None


def knn_points(
@@ -102,6 +115,7 @@ def knn_points(
     p2: torch.Tensor,
     lengths1: Union[torch.Tensor, None] = None,
     lengths2: Union[torch.Tensor, None] = None,
+    norm: int = 2,
     K: int = 1,
     version: int = -1,
     return_nn: bool = False,
@@ -121,6 +135,7 @@ def knn_points(
         lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
             length of each pointcloud in p2. Or None to indicate that every cloud has
             length P2.
+        norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
         K: Integer giving the number of nearest neighbors to return.
         version: Which KNN implementation to use in the backend. If version=-1,
             the correct implementation is selected based on the shapes of the inputs.
@@ -172,7 +187,7 @@ def knn_points(

     # pyre-fixme[16]: `_knn_points` has no attribute `apply`.
     p1_dists, p1_idx = _knn_points.apply(
-        p1, p2, lengths1, lengths2, K, version, return_sorted
+        p1, p2, lengths1, lengths2, K, version, norm, return_sorted
     )

     p2_nn = None
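
One caller-facing subtlety visible in this diff: `norm` is inserted before `K` in the public `knn_points` signature, while `_knn_points.apply` takes it after `version`. An old positional call such as `knn_points(p1, p2, lengths1, lengths2, 5)` would therefore now bind 5 to `norm`; keyword arguments avoid the pitfall. A sketch (illustration only):

    import torch
    from pytorch3d.ops import knn_points

    p1 = torch.rand(2, 100, 3)
    p2 = torch.rand(2, 120, 3)

    # Keywords keep the call robust to the signature change.
    out = knn_points(p1, p2, norm=1, K=8)
    print(out.dists.shape, out.idx.shape)  # torch.Size([2, 100, 8]) for both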
tests/test_chamfer.py

@@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         )

     @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
+    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(x_lengths[n]):
                 for i2 in range(y_lengths[n]):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         x_dist = torch.min(dist, dim=2)[0]  # (N, P1)
         y_dist = torch.min(dist, dim=1)[0]  # (N, P2)
@@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         return loss, lnorm

     @staticmethod
-    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
+    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         Returns lists of the unreduced loss and loss_normals. This naive
@@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(P1):
                 for i2 in range(P2):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))

         loss = [
             torch.min(dist, dim=2)[0],  # (N, P1)
@@ -208,30 +222,34 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 7, 10, 18
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        p1 = points_normals.p1
-        p2 = points_normals.p2
-        weights = points_normals.weights
-        p11 = p1.detach().clone()
-        p22 = p2.detach().clone()
-        p11.requires_grad = True
-        p22.requires_grad = True
-        P1 = p1.shape[1]
-        P2 = p2.shape[1]
-
-        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
-
-        # point_reduction = "mean".
-        loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
-        pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
-        pred_loss *= weights
-        pred_loss = pred_loss.sum() / weights.sum()
-
-        self.assertClose(loss, pred_loss)
-        self.assertTrue(loss_norm is None)
-
-        # Check gradients
-        self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            p1 = points_normals.p1
+            p2 = points_normals.p2
+            weights = points_normals.weights
+            p11 = p1.detach().clone()
+            p22 = p2.detach().clone()
+            p11.requires_grad = True
+            p22.requires_grad = True
+            P1 = p1.shape[1]
+            P2 = p2.shape[1]
+
+            pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+                p1, p2, norm=norm
+            )
+
+            # point_reduction = "mean".
+            loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
+            pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
+            pred_loss *= weights
+            pred_loss = pred_loss.sum() / weights.sum()
+
+            self.assertClose(loss, pred_loss)
+            self.assertTrue(loss_norm is None)
+
+            # Check gradients
+            self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)

     def test_chamfer_vs_naive_pointcloud(self):
         """
@@ -242,63 +260,67 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 3, 70, 70
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        weights = points_normals.weights
-        x_lengths = points_normals.p1_lengths
-        y_lengths = points_normals.p2_lengths
-
-        # Chamfer with tensors as input for heterogeneous pointclouds.
-        cham_tensor, norm_tensor = chamfer_distance(
-            points_normals.p1,
-            points_normals.p2,
-            x_normals=points_normals.n1,
-            y_normals=points_normals.n2,
-            x_lengths=points_normals.p1_lengths,
-            y_lengths=points_normals.p2_lengths,
-            weights=weights,
-        )
-
-        # Chamfer with pointclouds as input.
-        pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-            points_normals.cloud1, points_normals.cloud2, device=device
-        )
-
-        # Mean reduction point loss.
-        pred_loss[0] *= weights.view(N, 1)
-        pred_loss[1] *= weights.view(N, 1)
-        pred_loss_mean = (
-            pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
-        )
-        pred_loss_mean = pred_loss_mean.sum()
-        pred_loss_mean /= weights.sum()
-
-        # Mean reduction norm loss.
-        pred_norm_loss[0] *= weights.view(N, 1)
-        pred_norm_loss[1] *= weights.view(N, 1)
-        pred_norm_loss_mean = (
-            pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
-        )
-        pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
-
-        self.assertClose(pred_loss_mean, cham_tensor)
-        self.assertClose(pred_norm_loss_mean, norm_tensor)
-
-        self._check_gradients(
-            cham_tensor,
-            norm_tensor,
-            pred_loss_mean,
-            pred_norm_loss_mean,
-            points_normals.cloud1.points_list(),
-            points_normals.p1,
-            points_normals.cloud2.points_list(),
-            points_normals.p2,
-            points_normals.cloud1.normals_list(),
-            points_normals.n1,
-            points_normals.cloud2.normals_list(),
-            points_normals.n2,
-            x_lengths,
-            y_lengths,
-        )
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            weights = points_normals.weights
+            x_lengths = points_normals.p1_lengths
+            y_lengths = points_normals.p2_lengths
+
+            # Chamfer with tensors as input for heterogeneous pointclouds.
+            cham_tensor, norm_tensor = chamfer_distance(
+                points_normals.p1,
+                points_normals.p2,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                x_lengths=points_normals.p1_lengths,
+                y_lengths=points_normals.p2_lengths,
+                weights=weights,
+                norm=norm,
+            )
+
+            # Chamfer with pointclouds as input.
+            pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
+                points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
+            )
+
+            # Mean reduction point loss.
+            pred_loss[0] *= weights.view(N, 1)
+            pred_loss[1] *= weights.view(N, 1)
+            pred_loss_mean = (
+                pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
+            )
+            pred_loss_mean = pred_loss_mean.sum()
+            pred_loss_mean /= weights.sum()
+
+            # Mean reduction norm loss.
+            pred_norm_loss[0] *= weights.view(N, 1)
+            pred_norm_loss[1] *= weights.view(N, 1)
+            pred_norm_loss_mean = (
+                pred_norm_loss[0].sum(1) / x_lengths
+                + pred_norm_loss[1].sum(1) / y_lengths
+            )
+            pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
+
+            self.assertClose(pred_loss_mean, cham_tensor)
+            self.assertClose(pred_norm_loss_mean, norm_tensor)
+
+            self._check_gradients(
+                cham_tensor,
+                norm_tensor,
+                pred_loss_mean,
+                pred_norm_loss_mean,
+                points_normals.cloud1.points_list(),
+                points_normals.p1,
+                points_normals.cloud2.points_list(),
+                points_normals.p2,
+                points_normals.cloud1.normals_list(),
+                points_normals.n1,
+                points_normals.cloud2.normals_list(),
+                points_normals.n2,
+                x_lengths,
+                y_lengths,
+            )

     def test_chamfer_pointcloud_object_withnormals(self):
         N = 5
@@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
             chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])

+    def test_invalid_norm(self):
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=0)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=3)
+
     @staticmethod
     def chamfer_with_init(
         batch_size: int,
tests/test_knn.py

@@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         torch.manual_seed(1)

     @staticmethod
-    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
+    def _knn_points_naive(
+        p1, p2, lengths1, lengths2, K: int, norm: int = 2
+    ) -> torch.Tensor:
         """
         Naive PyTorch implementation of K-Nearest Neighbors.
         Returns always sorted results
@@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             pp1 = p1[n, :num1].view(num1, 1, D)
             pp2 = p2[n, :num2].view(1, num2, D)
             diff = pp1 - pp2
-            diff = (diff * diff).sum(2)
+            if norm == 2:
+                diff = (diff * diff).sum(2)
+            elif norm == 1:
+                diff = diff.abs().sum(2)
+            else:
+                raise ValueError("No support for norm %d" % (norm))
             num2 = min(num2, K)
             for i in range(num1):
                 dd = diff[i]
@@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
+        norms = [1, 2]
         versions = [0, 1, 2, 3]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             for version in versions:
                 if version == 3 and K > 4:
                     continue
@@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
                 y_cuda.requires_grad_(True)

                 # forward
-                out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
+                out1 = self._knn_points_naive(
+                    x, y, lengths1=None, lengths2=None, K=K, norm=norm
+                )
                 out2 = knn_points(
-                    x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
+                    x_cuda,
+                    y_cuda,
+                    K=K,
+                    norm=norm,
+                    version=version,
+                    return_sorted=return_sorted,
                 )
                 if K > 1 and not return_sorted:
                     # check out2 is not sorted
@@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        norms = [1, 2]
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             x = torch.rand((N, P1, D), device=device, requires_grad=True)
             y = torch.rand((N, P2, D), device=device, requires_grad=True)
             lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):

             # forward
             out1 = self._knn_points_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
+            )
+            out2 = knn_points(
+                x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
             )
-            out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
             self.assertClose(out1[0], out2[0])
             self.assertTrue(torch.all(out1[1] == out2[1]))

@@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             expected = all_expected[version]
             self.assertEqual(actual, expected)

+    def test_invalid_norm(self):
+        device = get_random_cuda_device()
+        N, P1, P2, K, D = 4, 16, 12, 8, 3
+        x = torch.rand((N, P1, D), device=device)
+        y = torch.rand((N, P2, D), device=device)
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=3)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=0)
+
     @staticmethod
     def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
         device = torch.device(device)