diff --git a/pytorch3d/csrc/knn/knn.cu b/pytorch3d/csrc/knn/knn.cu
index ca1cc07b..19193405 100644
--- a/pytorch3d/csrc/knn/knn.cu
+++ b/pytorch3d/csrc/knn/knn.cu
@@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
     const size_t P1,
     const size_t P2,
     const size_t D,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Store both dists and indices for knn in global memory.
   const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
         scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
         scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
         scalar_t diff = coord1 - coord2;
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
     const size_t N,
     const size_t P1,
     const size_t P2,
-    const size_t K) {
+    const size_t K,
+    const size_t norm) {
   // Same idea as the previous version, but hoist D into a template argument
   // so we can cache the current point in a thread-local array. We still store
   // the current best K dists and indices in global memory, so this should work
@@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
       scalar_t dist = 0;
       for (int d = 0; d < D; ++d) {
         scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
       const size_t N,
       const size_t P1,
       const size_t P2,
-      const size_t K) {
+      const size_t K,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
   }
 };
@@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
     int64_t* __restrict__ idxs,
     const int64_t N,
     const int64_t P1,
-    const int64_t P2) {
+    const int64_t P2,
+    const size_t norm) {
   // Same general implementation as V1, but also hoist K into a template arg.
   scalar_t cur_point[D];
   scalar_t min_dists[K];
@@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
       int64_t* __restrict__ idxs,
       const int64_t N,
       const int64_t P1,
-      const int64_t P2) {
+      const int64_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
@@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
     int64_t* __restrict__ idxs,
     const size_t N,
     const size_t P1,
-    const size_t P2) {
+    const size_t P2,
+    const size_t norm) {
   // Same idea as V2, but use register indexing for thread-local arrays.
   // Enabling sorting for this version leads to huge slowdowns; I suspect
   // that it forces min_dists into local memory rather than registers.
@@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
       for (int d = 0; d < D; ++d) {
         int offset = n * P2 * D + p2 * D + d;
         scalar_t diff = cur_point[d] - points2[offset];
-        dist += diff * diff;
+        scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+        dist += norm_diff;
       }
       mink.add(dist, p2);
     }
@@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
       int64_t* __restrict__ idxs,
       const size_t N,
       const size_t P1,
-      const size_t P2) {
+      const size_t P2,
+      const size_t norm) {
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
-        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
+        points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
   }
 };
@@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
+    const int norm,
+    const int K,
     int version) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
   const auto D = p2.size(2);
   const int64_t K_64 = K;
 
+  TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
   TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
   auto long_dtype = lengths1.options().dtype(at::kLong);
   auto idxs = at::zeros({N, P1, K}, long_dtype);
@@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
             P1,
             P2,
             D,
-            K);
+            K,
+            norm);
       }));
   } else if (version == 1) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
             N,
             P1,
             P2,
-            K);
+            K,
+            norm);
       }));
   } else if (version == 2) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
             idxs.data_ptr<int64_t>(),
             N,
             P1,
-            P2);
+            P2,
+            norm);
       }));
   } else if (version == 3) {
     AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
             idxs.data_ptr<int64_t>(),
             N,
             P1,
-            P2);
+            P2,
+            norm);
       }));
   }
   AT_CUDA_CHECK(cudaGetLastError());
@@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
     const size_t P1,
     const size_t P2,
     const size_t K,
-    const size_t D) {
+    const size_t D,
+    const size_t norm) {
   const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
   const size_t stride = gridDim.x * blockDim.x;
 
@@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
       if (p2_idx == -1) {
         continue;
       }
-      const float diff = 2.0 * grad_dist *
-          (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      float diff = 0.0;
+      if (norm == 1) {
+        float sign =
+            (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+            ? 1.0
+            : -1.0;
+        diff = grad_dist * sign;
+      } else { // norm is 2
+        diff = 2.0 * grad_dist *
+            (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+      }
       atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
       atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
     }
@@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    int norm,
    const at::Tensor& grad_dists) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
       P1,
       P2,
       K,
-      D);
+      D,
+      norm);
 
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_p1, grad_p2);
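
The math implemented by the forward kernels and the backward kernel above, as a minimal PyTorch sketch for reference (this mirrors the formulas only, not the kernel code; note that for norm=1 the kernels resolve the subgradient of |p1 - p2| at p1 == p2 to -1, whereas torch.sign would give 0):

import torch

def pairwise_dists(p1, p2, norm=2):
    # p1: (P1, D), p2: (P2, D) -> (P1, P2).
    # Per-coordinate term, as in the kernels: diff * diff (norm 2) or abs(diff) (norm 1).
    diff = p1[:, None, :] - p2[None, :, :]
    return (diff * diff).sum(-1) if norm == 2 else diff.abs().sum(-1)

# Backward rule per coordinate d, as in KNearestNeighborBackwardKernel:
#   norm == 2: diff = grad_dist * 2 * (p1_d - p2_d)
#   norm == 1: diff = grad_dist * sign(p1_d - p2_d)   (subgradient at zero)
# grad_p1 accumulates +diff and grad_p2 accumulates -diff via atomicAdd.
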
diff --git a/pytorch3d/csrc/knn/knn.h b/pytorch3d/csrc/knn/knn.h
index aa9cfb54..7fc8d488 100644
--- a/pytorch3d/csrc/knn/knn.h
+++ b/pytorch3d/csrc/knn/knn.h
@@ -21,6 +21,7 @@
 //        containing P2 points of dimension D.
 //    lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
 //    lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 //    K: int giving the number of nearest points to return.
 //    version: Integer telling which implementation to use.
 //
@@ -41,7 +42,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K);
+    const int norm,
+    const int K);
 
 // CUDA implementation
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
@@ -49,8 +51,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
-    int version);
+    const int norm,
+    const int K,
+    const int version);
 
 // Implementation which is exposed.
 std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
@@ -58,18 +61,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K,
-    int version) {
+    const int norm,
+    const int K,
+    const int version) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(p1);
     CHECK_CUDA(p2);
-    return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
+    return KNearestNeighborIdxCuda(
+        p1, p2, lengths1, lengths2, norm, K, version);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
-  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
+  return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
 }
 
 // Compute gradients with respect to p1 and p2
@@ -86,6 +91,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
 //        neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
 //        It is padded with zeros so that it can be used easily in a later
 //        gather() operation. This is computed from the forward pass.
+//    norm: int specifying the norm for the distance (1 for L1, 2 for L2)
 //    grad_dists: FloatTensor of shape (N, P1, K) which contains the input
 //        gradients.
 //
@@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists);
 
 // CUDA implementation
@@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists);
 
 // Implementation which is exposed.
@@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(p1);
     CHECK_CUDA(p2);
     return KNearestNeighborBackwardCuda(
-        p1, p2, lengths1, lengths2, idxs, grad_dists);
+        p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
   return KNearestNeighborBackwardCpu(
-      p1, p2, lengths1, lengths2, idxs, grad_dists);
+      p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
 }
 
 // Utility to check whether a KNN version can be used.
diff --git a/pytorch3d/csrc/knn/knn_cpu.cpp b/pytorch3d/csrc/knn/knn_cpu.cpp
index dc93fd41..9e3153a6 100644
--- a/pytorch3d/csrc/knn/knn_cpu.cpp
+++ b/pytorch3d/csrc/knn/knn_cpu.cpp
@@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
     const at::Tensor& p2,
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
-    int K) {
+    const int norm,
+    const int K) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int D = p1.size(2);
@@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
         float dist = 0;
         for (int d = 0; d < D; ++d) {
           float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
-          dist += diff * diff;
+          if (norm == 1) {
+            dist += abs(diff);
+          } else { // norm is 2 (default)
+            dist += diff * diff;
+          }
         }
         int size = static_cast<int>(q.size());
         if (size < K || dist < std::get<0>(q.top())) {
@@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const at::Tensor& idxs,
+    const int norm,
     const at::Tensor& grad_dists) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
@@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
           continue;
         }
         for (int64_t d = 0; d < D; ++d) {
-          const float diff =
-              2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          float diff = 0.0;
+          if (norm == 1) {
+            float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
+            diff = grad_dists_a[n][i1][k] * sign;
+          } else { // norm is 2 (default)
+            diff = 2.0f * grad_dists_a[n][i1][k] *
+                (p1_a[n][i1][d] - p2_a[n][i2][d]);
+          }
           grad_p1_a[n][i1][d] += diff;
           grad_p2_a[n][i2][d] += -1.0f * diff;
         }
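
The CPU forward keeps, for each query point, the K best candidates seen so far in a max-heap (std::priority_queue ordered by distance) and replaces the current worst entry whenever a closer point arrives; that is what the size < K || dist < std::get<0>(q.top()) test above implements. A rough Python equivalent of the selection logic, assuming plain lists of D-dimensional tuples (an illustrative sketch, not the library's API):

import heapq

def top_k_neighbors(query, points, K, norm=2):
    # Max-heap via negated distances; holds at most K (-distance, index) pairs.
    heap = []
    for j, p in enumerate(points):
        dist = sum((a - b) ** 2 if norm == 2 else abs(a - b) for a, b in zip(query, p))
        if len(heap) < K:
            heapq.heappush(heap, (-dist, j))
        elif dist < -heap[0][0]:  # strictly closer than the current worst
            heapq.heapreplace(heap, (-dist, j))
    return sorted((-d, j) for d, j in heap)  # ascending (distance, index)
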
diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py
index 744a1e2a..4d65989e 100644
--- a/pytorch3d/loss/chamfer.py
+++ b/pytorch3d/loss/chamfer.py
@@ -77,6 +77,7 @@ def chamfer_distance(
     weights=None,
     batch_reduction: Union[str, None] = "mean",
     point_reduction: str = "mean",
+    norm: int = 2,
 ):
     """
     Chamfer distance between two pointclouds x and y.
@@ -100,6 +101,7 @@ def chamfer_distance(
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
             points, can be one of ["mean", "sum"].
+        norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
 
     Returns:
         2-element tuple containing
@@ -112,6 +114,9 @@ def chamfer_distance(
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
+    if not ((norm == 1) or (norm == 2)):
+        raise ValueError("Support for 1 or 2 norm.")
+
     x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
     y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
 
@@ -149,8 +154,8 @@ def chamfer_distance(
         cham_norm_x = x.new_zeros(())
         cham_norm_y = x.new_zeros(())
 
-    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
-    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
+    x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
+    y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
 
     cham_x = x_nn.dists[..., 0]  # (N, P1)
     cham_y = y_nn.dists[..., 0]  # (N, P2)
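
With the new argument, the loss can be computed under either norm from user code; a minimal usage sketch (batch sizes, point counts, and device are arbitrary here):

import torch
from pytorch3d.loss import chamfer_distance

x = torch.randn(4, 128, 3, device="cuda")  # (N, P1, D)
y = torch.randn(4, 256, 3, device="cuda")  # (N, P2, D)

loss_l2, _ = chamfer_distance(x, y)          # default norm=2: squared-L2 distances
loss_l1, _ = chamfer_distance(x, y, norm=1)  # L1 distances
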
diff --git a/pytorch3d/ops/ball_query.py b/pytorch3d/ops/ball_query.py
index c82a0e7d..821df06f 100644
--- a/pytorch3d/ops/ball_query.py
+++ b/pytorch3d/ops/ball_query.py
@@ -43,8 +43,9 @@ class _ball_query(Function):
             p2 = p2.float()
 
         # Reuse the KNN backward function
+        # by default, norm is 2
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, 2, grad_dists
         )
         return grad_p1, grad_p2, None, None, None, None
diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py
index bd6f673f..72e3c289 100644
--- a/pytorch3d/ops/knn.py
+++ b/pytorch3d/ops/knn.py
@@ -24,7 +24,15 @@ class _knn_points(Function):
     @staticmethod
     # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
     def forward(
-        ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
+        ctx,
+        p1,
+        p2,
+        lengths1,
+        lengths2,
+        K,
+        version,
+        norm: int = 2,
+        return_sorted: bool = True,
     ):
         """
         K-Nearest neighbors on point clouds.
@@ -43,6 +51,7 @@ class _knn_points(Function):
             K: Integer giving the number of nearest neighbors to return.
             version: Which KNN implementation to use in the backend. If version=-1,
                 the correct implementation is selected based on the shapes of the inputs.
+            norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
             return_sorted: (bool) whether to return the nearest neighbors sorted in
                 ascending order of distance.
 
@@ -57,8 +66,10 @@ class _knn_points(Function):
             neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
             in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
         """
+        if not ((norm == 1) or (norm == 2)):
+            raise ValueError("Support for 1 or 2 norm.")
 
-        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
+        idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
 
         # sort KNN in ascending order if K > 1
         if K > 1 and return_sorted:
@@ -78,12 +89,14 @@ class _knn_points(Function):
 
         ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
         ctx.mark_non_differentiable(idx)
+        ctx.norm = norm
         return dists, idx
 
     @staticmethod
     @once_differentiable
     def backward(ctx, grad_dists, grad_idx):
         p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+        norm = ctx.norm
         # TODO(gkioxari) Change cast to floats once we add support for doubles.
         if not (grad_dists.dtype == torch.float32):
             grad_dists = grad_dists.float()
@@ -92,9 +105,9 @@ class _knn_points(Function):
         if not (p2.dtype == torch.float32):
             p2 = p2.float()
         grad_p1, grad_p2 = _C.knn_points_backward(
-            p1, p2, lengths1, lengths2, idx, grad_dists
+            p1, p2, lengths1, lengths2, idx, norm, grad_dists
         )
-        return grad_p1, grad_p2, None, None, None, None, None
+        return grad_p1, grad_p2, None, None, None, None, None, None
 
 
 def knn_points(
@@ -102,6 +115,7 @@ def knn_points(
     p2: torch.Tensor,
     lengths1: Union[torch.Tensor, None] = None,
     lengths2: Union[torch.Tensor, None] = None,
+    norm: int = 2,
     K: int = 1,
     version: int = -1,
     return_nn: bool = False,
@@ -121,6 +135,7 @@ def knn_points(
         lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
             length of each pointcloud in p2. Or None to indicate that every cloud has
             length P2.
+        norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
         K: Integer giving the number of nearest neighbors to return.
         version: Which KNN implementation to use in the backend. If version=-1,
             the correct implementation is selected based on the shapes of the inputs.
@@ -172,7 +187,7 @@ def knn_points(
 
     # pyre-fixme[16]: `_knn_points` has no attribute `apply`.
     p1_dists, p1_idx = _knn_points.apply(
-        p1, p2, lengths1, lengths2, K, version, return_sorted
+        p1, p2, lengths1, lengths2, K, version, norm, return_sorted
     )
 
     p2_nn = None
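
A usage sketch for the public entry point (shapes are arbitrary; note that norm is accepted before K by knn_points, while _knn_points.apply receives it after version, matching the forward signature above):

import torch
from pytorch3d.ops import knn_points

p1 = torch.randn(2, 100, 3, device="cuda")
p2 = torch.randn(2, 120, 3, device="cuda")

# 3 nearest neighbors under the L1 norm; dists holds sums of |p1 - p2| over D.
knn = knn_points(p1, p2, norm=1, K=3)
dists, idx = knn.dists, knn.idx  # (2, 100, 3) each
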
diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py
index 68d4b690..0231f41f 100644
--- a/tests/test_chamfer.py
+++ b/tests/test_chamfer.py
@@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         )
 
     @staticmethod
-    def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
+    def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         x and y are assumed to be pointclouds objects with points and optionally normals.
@@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(x_lengths[n]):
                 for i2 in range(y_lengths[n]):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))
 
         x_dist = torch.min(dist, dim=2)[0]  # (N, P1)
         y_dist = torch.min(dist, dim=1)[0]  # (N, P2)
@@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         return loss, lnorm
 
     @staticmethod
-    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
+    def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
         """
         Naive iterative implementation of nearest neighbor and chamfer distance.
         Returns lists of the unreduced loss and loss_normals. This naive
@@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         for n in range(N):
             for i1 in range(P1):
                 for i2 in range(P2):
-                    dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    if norm == 2:
+                        dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
+                    elif norm == 1:
+                        dist[n, i1, i2] = torch.sum(
+                            torch.abs(x[n, i1, :] - y[n, i2, :])
+                        )
+                    else:
+                        raise ValueError("No support for norm %d" % (norm))
 
         loss = [
             torch.min(dist, dim=2)[0],  # (N, P1)
@@ -208,30 +222,34 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
        N, max_P1, max_P2 = 7, 10, 18
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        p1 = points_normals.p1
-        p2 = points_normals.p2
-        weights = points_normals.weights
-        p11 = p1.detach().clone()
-        p22 = p2.detach().clone()
-        p11.requires_grad = True
-        p22.requires_grad = True
-        P1 = p1.shape[1]
-        P2 = p2.shape[1]
 
-        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            p1 = points_normals.p1
+            p2 = points_normals.p2
+            weights = points_normals.weights
+            p11 = p1.detach().clone()
+            p22 = p2.detach().clone()
+            p11.requires_grad = True
+            p22.requires_grad = True
+            P1 = p1.shape[1]
+            P2 = p2.shape[1]
 
-        # point_reduction = "mean".
-        loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
-        pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
-        pred_loss *= weights
-        pred_loss = pred_loss.sum() / weights.sum()
+            pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+                p1, p2, norm=norm
+            )
 
-        self.assertClose(loss, pred_loss)
-        self.assertTrue(loss_norm is None)
+            # point_reduction = "mean".
+            loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
+            pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
+            pred_loss *= weights
+            pred_loss = pred_loss.sum() / weights.sum()
 
-        # Check gradients
-        self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
+            self.assertClose(loss, pred_loss)
+            self.assertTrue(loss_norm is None)
+
+            # Check gradients
+            self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
 
     def test_chamfer_vs_naive_pointcloud(self):
         """
@@ -242,63 +260,67 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         """
         N, max_P1, max_P2 = 3, 70, 70
         device = get_random_cuda_device()
-        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
-        weights = points_normals.weights
-        x_lengths = points_normals.p1_lengths
-        y_lengths = points_normals.p2_lengths
 
-        # Chamfer with tensors as input for heterogeneous pointclouds.
-        cham_tensor, norm_tensor = chamfer_distance(
-            points_normals.p1,
-            points_normals.p2,
-            x_normals=points_normals.n1,
-            y_normals=points_normals.n2,
-            x_lengths=points_normals.p1_lengths,
-            y_lengths=points_normals.p2_lengths,
-            weights=weights,
-        )
+        for norm in [1, 2]:
+            points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+            weights = points_normals.weights
+            x_lengths = points_normals.p1_lengths
+            y_lengths = points_normals.p2_lengths
 
-        # Chamfer with pointclouds as input.
-        pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
-            points_normals.cloud1, points_normals.cloud2, device=device
-        )
+            # Chamfer with tensors as input for heterogeneous pointclouds.
+            cham_tensor, norm_tensor = chamfer_distance(
+                points_normals.p1,
+                points_normals.p2,
+                x_normals=points_normals.n1,
+                y_normals=points_normals.n2,
+                x_lengths=points_normals.p1_lengths,
+                y_lengths=points_normals.p2_lengths,
+                weights=weights,
+                norm=norm,
+            )
 
-        # Mean reduction point loss.
-        pred_loss[0] *= weights.view(N, 1)
-        pred_loss[1] *= weights.view(N, 1)
-        pred_loss_mean = (
-            pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
-        )
-        pred_loss_mean = pred_loss_mean.sum()
-        pred_loss_mean /= weights.sum()
+            # Chamfer with pointclouds as input.
+            pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
+                points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
+            )
 
-        # Mean reduction norm loss.
-        pred_norm_loss[0] *= weights.view(N, 1)
-        pred_norm_loss[1] *= weights.view(N, 1)
-        pred_norm_loss_mean = (
-            pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
-        )
-        pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
+            # Mean reduction point loss.
+            pred_loss[0] *= weights.view(N, 1)
+            pred_loss[1] *= weights.view(N, 1)
+            pred_loss_mean = (
+                pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
+            )
+            pred_loss_mean = pred_loss_mean.sum()
+            pred_loss_mean /= weights.sum()
 
-        self.assertClose(pred_loss_mean, cham_tensor)
-        self.assertClose(pred_norm_loss_mean, norm_tensor)
+            # Mean reduction norm loss.
+            pred_norm_loss[0] *= weights.view(N, 1)
+            pred_norm_loss[1] *= weights.view(N, 1)
+            pred_norm_loss_mean = (
+                pred_norm_loss[0].sum(1) / x_lengths
+                + pred_norm_loss[1].sum(1) / y_lengths
+            )
+            pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
 
-        self._check_gradients(
-            cham_tensor,
-            norm_tensor,
-            pred_loss_mean,
-            pred_norm_loss_mean,
-            points_normals.cloud1.points_list(),
-            points_normals.p1,
-            points_normals.cloud2.points_list(),
-            points_normals.p2,
-            points_normals.cloud1.normals_list(),
-            points_normals.n1,
-            points_normals.cloud2.normals_list(),
-            points_normals.n2,
-            x_lengths,
-            y_lengths,
-        )
+            self.assertClose(pred_loss_mean, cham_tensor)
+            self.assertClose(pred_norm_loss_mean, norm_tensor)
+
+            self._check_gradients(
+                cham_tensor,
+                norm_tensor,
+                pred_loss_mean,
+                pred_norm_loss_mean,
+                points_normals.cloud1.points_list(),
+                points_normals.p1,
+                points_normals.cloud2.points_list(),
+                points_normals.p2,
+                points_normals.cloud1.normals_list(),
+                points_normals.n1,
+                points_normals.cloud2.normals_list(),
+                points_normals.n2,
+                x_lengths,
+                y_lengths,
+            )
 
     def test_chamfer_pointcloud_object_withnormals(self):
         N = 5
@@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
         with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
             chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
 
+    def test_invalid_norm(self):
+        N, P1, P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=0)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            chamfer_distance(p1, p2, norm=3)
+
     @staticmethod
     def chamfer_with_init(
         batch_size: int,
diff --git a/tests/test_knn.py b/tests/test_knn.py
index 8e87fcc1..6318ff0c 100644
--- a/tests/test_knn.py
+++ b/tests/test_knn.py
@@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         torch.manual_seed(1)
 
     @staticmethod
-    def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
+    def _knn_points_naive(
+        p1, p2, lengths1, lengths2, K: int, norm: int = 2
+    ) -> torch.Tensor:
         """
         Naive PyTorch implementation of K-Nearest Neighbors.
         Returns always sorted results
@@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
                 pp1 = p1[n, :num1].view(num1, 1, D)
                 pp2 = p2[n, :num2].view(1, num2, D)
                 diff = pp1 - pp2
-                diff = (diff * diff).sum(2)
+                if norm == 2:
+                    diff = (diff * diff).sum(2)
+                elif norm == 1:
+                    diff = diff.abs().sum(2)
+                else:
+                    raise ValueError("No support for norm %d" % (norm))
                 num2 = min(num2, K)
                 for i in range(num1):
                     dd = diff[i]
@@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
+        norms = [1, 2]
         versions = [0, 1, 2, 3]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             for version in versions:
                 if version == 3 and K > 4:
                     continue
@@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
                 y_cuda.requires_grad_(True)
 
                 # forward
-                out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
+                out1 = self._knn_points_naive(
+                    x, y, lengths1=None, lengths2=None, K=K, norm=norm
+                )
                 out2 = knn_points(
-                    x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
+                    x_cuda,
+                    y_cuda,
+                    K=K,
+                    norm=norm,
+                    version=version,
+                    return_sorted=return_sorted,
                 )
                 if K > 1 and not return_sorted:
                     # check out2 is not sorted
@@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         P1s = [8, 24]
         P2s = [8, 16, 32]
         Ks = [1, 3, 10]
-        factors = [Ns, Ds, P1s, P2s, Ks]
-        for N, D, P1, P2, K in product(*factors):
+        norms = [1, 2]
+        factors = [Ns, Ds, P1s, P2s, Ks, norms]
+        for N, D, P1, P2, K, norm in product(*factors):
             x = torch.rand((N, P1, D), device=device, requires_grad=True)
             y = torch.rand((N, P2, D), device=device, requires_grad=True)
             lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
 
             # forward
             out1 = self._knn_points_naive(
-                x, y, lengths1=lengths1, lengths2=lengths2, K=K
+                x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
+            )
+            out2 = knn_points(
+                x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
             )
-            out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
 
             self.assertClose(out1[0], out2[0])
             self.assertTrue(torch.all(out1[1] == out2[1]))
@@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
             expected = all_expected[version]
             self.assertEqual(actual, expected)
 
+    def test_invalid_norm(self):
+        device = get_random_cuda_device()
+        N, P1, P2, K, D = 4, 16, 12, 8, 3
+        x = torch.rand((N, P1, D), device=device)
+        y = torch.rand((N, P2, D), device=device)
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=3)
+
+        with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
+            knn_points(x, y, K=K, norm=0)
+
     @staticmethod
     def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
         device = torch.device(device)
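
Beyond these unit tests, the new backward paths can be sanity-checked against autograd on densely computed distances; a hedged sketch (with random inputs, exact ties — where the kernels' sign convention differs from autograd's sign(0) = 0 — are vanishingly unlikely):

import torch
from pytorch3d.ops import knn_points

torch.manual_seed(0)
p1 = torch.randn(2, 16, 3, device="cuda", requires_grad=True)
p2 = torch.randn(2, 20, 3, device="cuda")

dists, _, _ = knn_points(p1, p2, norm=1, K=2)
dists.sum().backward()

# Autograd reference over the same L1 distances.
p1_ref = p1.detach().clone().requires_grad_(True)
diff = p1_ref[:, :, None, :] - p2[:, None, :, :]  # (N, P1, P2, D)
ref = diff.abs().sum(-1).topk(2, dim=2, largest=False).values
ref.sum().backward()
assert torch.allclose(p1.grad, p1_ref.grad)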