add L1 support for KNN & Chamfer

Summary:
Added L1 norm for KNN and chamfer op
* The norm is now specified with a variable `norm` which can only be 1 or 2

Reviewed By: bottler

Differential Revision: D35419637

fbshipit-source-id: 77813fec650b30c28342af90d5ed02c89133e136
This commit is contained in:
Georgia Gkioxari 2022-04-10 10:27:20 -07:00 committed by Facebook GitHub Bot
parent 4b94649f7b
commit 67fff956a2
8 changed files with 265 additions and 129 deletions

View File

@ -36,7 +36,8 @@ __global__ void KNearestNeighborKernelV0(
const size_t P1,
const size_t P2,
const size_t D,
const size_t K) {
const size_t K,
const size_t norm) {
// Store both dists and indices for knn in global memory.
const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
@ -56,7 +57,8 @@ __global__ void KNearestNeighborKernelV0(
scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
scalar_t diff = coord1 - coord2;
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
@ -74,7 +76,8 @@ __global__ void KNearestNeighborKernelV1(
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
const size_t K,
const size_t norm) {
// Same idea as the previous version, but hoist D into a template argument
// so we can cache the current point in a thread-local array. We still store
// the current best K dists and indices in global memory, so this should work
@ -99,7 +102,8 @@ __global__ void KNearestNeighborKernelV1(
scalar_t dist = 0;
for (int d = 0; d < D; ++d) {
scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
@ -121,10 +125,11 @@ struct KNearestNeighborV1Functor {
const size_t N,
const size_t P1,
const size_t P2,
const size_t K) {
const size_t K,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
}
};
@ -138,7 +143,8 @@ __global__ void KNearestNeighborKernelV2(
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
const int64_t P2,
const size_t norm) {
// Same general implementation as V2, but also hoist K into a template arg.
scalar_t cur_point[D];
scalar_t min_dists[K];
@ -161,7 +167,8 @@ __global__ void KNearestNeighborKernelV2(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
@ -186,10 +193,11 @@ struct KNearestNeighborKernelV2Functor {
int64_t* __restrict__ idxs,
const int64_t N,
const int64_t P1,
const int64_t P2) {
const int64_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
@ -203,7 +211,8 @@ __global__ void KNearestNeighborKernelV3(
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
const size_t P2,
const size_t norm) {
// Same idea as V2, but use register indexing for thread-local arrays.
// Enabling sorting for this version leads to huge slowdowns; I suspect
// that it forces min_dists into local memory rather than registers.
@ -229,7 +238,8 @@ __global__ void KNearestNeighborKernelV3(
for (int d = 0; d < D; ++d) {
int offset = n * P2 * D + p2 * D + d;
scalar_t diff = cur_point[d] - points2[offset];
dist += diff * diff;
scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
dist += norm_diff;
}
mink.add(dist, p2);
}
@ -254,10 +264,11 @@ struct KNearestNeighborKernelV3Functor {
int64_t* __restrict__ idxs,
const size_t N,
const size_t P1,
const size_t P2) {
const size_t P2,
const size_t norm) {
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2);
points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
}
};
@ -305,7 +316,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
const int norm,
const int K,
int version) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@ -324,6 +336,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const auto D = p2.size(2);
const int64_t K_64 = K;
TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
TORCH_CHECK(p2.size(2) == D, "Point sets must have the same last dimension");
auto long_dtype = lengths1.options().dtype(at::kLong);
auto idxs = at::zeros({N, P1, K}, long_dtype);
@ -366,7 +380,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
P1,
P2,
D,
K);
K,
norm);
}));
} else if (version == 1) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@ -387,7 +402,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
N,
P1,
P2,
K);
K,
norm);
}));
} else if (version == 2) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@ -410,7 +426,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
P2,
norm);
}));
} else if (version == 3) {
AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
@ -433,7 +450,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
idxs.data_ptr<int64_t>(),
N,
P1,
P2);
P2,
norm);
}));
}
AT_CUDA_CHECK(cudaGetLastError());
@ -459,7 +477,8 @@ __global__ void KNearestNeighborBackwardKernel(
const size_t P1,
const size_t P2,
const size_t K,
const size_t D) {
const size_t D,
const size_t norm) {
const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
const size_t stride = gridDim.x * blockDim.x;
@ -481,8 +500,17 @@ __global__ void KNearestNeighborBackwardKernel(
if (p2_idx == -1) {
continue;
}
const float diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
float diff = 0.0;
if (norm == 1) {
float sign =
(p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
? 1.0
: -1.0;
diff = grad_dist * sign;
} else { // norm is 2
diff = 2.0 * grad_dist *
(p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
}
atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
}
@ -495,6 +523,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
int norm,
const at::Tensor& grad_dists) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
@ -547,7 +576,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
P1,
P2,
K,
D);
D,
norm);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_p1, grad_p2);

View File

@ -21,6 +21,7 @@
// containing P2 points of dimension D.
// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// K: int giving the number of nearest points to return.
// version: Integer telling which implementation to use.
//
@ -41,7 +42,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K);
const int norm,
const int K);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
@ -49,8 +51,9 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version);
const int norm,
const int K,
const int version);
// Implementation which is exposed.
std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
@ -58,18 +61,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
int version) {
const int norm,
const int K,
const int version) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborIdxCuda(p1, p2, lengths1, lengths2, K, version);
return KNearestNeighborIdxCuda(
p1, p2, lengths1, lengths2, norm, K, version);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, K);
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}
// Compute gradients with respect to p1 and p2
@ -86,6 +91,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
// It is padded with zeros so that it can be used easily in a later
// gather() operation. This is computed from the forward pass.
// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input
// gradients.
//
@ -102,6 +108,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);
// CUDA implementation
@ -111,6 +118,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists);
// Implementation which is exposed.
@ -120,19 +128,20 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
CHECK_CUDA(p2);
return KNearestNeighborBackwardCuda(
p1, p2, lengths1, lengths2, idxs, grad_dists);
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return KNearestNeighborBackwardCpu(
p1, p2, lengths1, lengths2, idxs, grad_dists);
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}
// Utility to check whether a KNN version can be used.

View File

@ -15,7 +15,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
const at::Tensor& p2,
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K) {
const int norm,
const int K) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
@ -41,7 +42,11 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
float dist = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
dist += diff * diff;
if (norm == 1) {
dist += abs(diff);
} else { // norm is 2 (default)
dist += diff * diff;
}
}
int size = static_cast<int>(q.size());
if (size < K || dist < std::get<0>(q.top())) {
@ -73,6 +78,7 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const at::Tensor& idxs,
const int norm,
const at::Tensor& grad_dists) {
const int N = p1.size(0);
const int P1 = p1.size(1);
@ -104,8 +110,14 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
continue;
}
for (int64_t d = 0; d < D; ++d) {
const float diff =
2.0f * grad_dists_a[n][i1][k] * (p1_a[n][i1][d] - p2_a[n][i2][d]);
float diff = 0.0;
if (norm == 1) {
float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
diff = grad_dists_a[n][i1][k] * sign;
} else { // norm is 2 (default)
diff = 2.0f * grad_dists_a[n][i1][k] *
(p1_a[n][i1][d] - p2_a[n][i2][d]);
}
grad_p1_a[n][i1][d] += diff;
grad_p2_a[n][i2][d] += -1.0f * diff;
}

View File

@ -77,6 +77,7 @@ def chamfer_distance(
weights=None,
batch_reduction: Union[str, None] = "mean",
point_reduction: str = "mean",
norm: int = 2,
):
"""
Chamfer distance between two pointclouds x and y.
@ -100,6 +101,7 @@ def chamfer_distance(
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"].
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
Returns:
2-element tuple containing
@ -112,6 +114,9 @@ def chamfer_distance(
"""
_validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
@ -149,8 +154,8 @@ def chamfer_distance(
cham_norm_x = x.new_zeros(())
cham_norm_y = x.new_zeros(())
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, norm=norm, K=1)
y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, norm=norm, K=1)
cham_x = x_nn.dists[..., 0] # (N, P1)
cham_y = y_nn.dists[..., 0] # (N, P2)

View File

@ -43,8 +43,9 @@ class _ball_query(Function):
p2 = p2.float()
# Reuse the KNN backward function
# by default, norm is 2
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
p1, p2, lengths1, lengths2, idx, 2, grad_dists
)
return grad_p1, grad_p2, None, None, None, None

View File

@ -24,7 +24,15 @@ class _knn_points(Function):
@staticmethod
# pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
def forward(
ctx, p1, p2, lengths1, lengths2, K, version, return_sorted: bool = True
ctx,
p1,
p2,
lengths1,
lengths2,
K,
version,
norm: int = 2,
return_sorted: bool = True,
):
"""
K-Nearest neighbors on point clouds.
@ -43,6 +51,7 @@ class _knn_points(Function):
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
return_sorted: (bool) whether to return the nearest neighbors sorted in
ascending order of distance.
@ -57,8 +66,10 @@ class _knn_points(Function):
neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
"""
if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, K, version)
idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
# sort KNN in ascending order if K > 1
if K > 1 and return_sorted:
@ -78,12 +89,14 @@ class _knn_points(Function):
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
ctx.mark_non_differentiable(idx)
ctx.norm = norm
return dists, idx
@staticmethod
@once_differentiable
def backward(ctx, grad_dists, grad_idx):
p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
norm = ctx.norm
# TODO(gkioxari) Change cast to floats once we add support for doubles.
if not (grad_dists.dtype == torch.float32):
grad_dists = grad_dists.float()
@ -92,9 +105,9 @@ class _knn_points(Function):
if not (p2.dtype == torch.float32):
p2 = p2.float()
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, grad_dists
p1, p2, lengths1, lengths2, idx, norm, grad_dists
)
return grad_p1, grad_p2, None, None, None, None, None
return grad_p1, grad_p2, None, None, None, None, None, None
def knn_points(
@ -102,6 +115,7 @@ def knn_points(
p2: torch.Tensor,
lengths1: Union[torch.Tensor, None] = None,
lengths2: Union[torch.Tensor, None] = None,
norm: int = 2,
K: int = 1,
version: int = -1,
return_nn: bool = False,
@ -121,6 +135,7 @@ def knn_points(
lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
length of each pointcloud in p2. Or None to indicate that every cloud has
length P2.
norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
K: Integer giving the number of nearest neighbors to return.
version: Which KNN implementation to use in the backend. If version=-1,
the correct implementation is selected based on the shapes of the inputs.
@ -172,7 +187,7 @@ def knn_points(
# pyre-fixme[16]: `_knn_points` has no attribute `apply`.
p1_dists, p1_idx = _knn_points.apply(
p1, p2, lengths1, lengths2, K, version, return_sorted
p1, p2, lengths1, lengths2, K, version, norm, return_sorted
)
p2_nn = None

View File

@ -87,7 +87,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
)
@staticmethod
def chamfer_distance_naive_pointclouds(p1, p2, device="cpu"):
def chamfer_distance_naive_pointclouds(p1, p2, norm: int = 2, device="cpu"):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
x and y are assumed to be pointclouds objects with points and optionally normals.
@ -121,7 +121,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for n in range(N):
for i1 in range(x_lengths[n]):
for i2 in range(y_lengths[n]):
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
if norm == 2:
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
elif norm == 1:
dist[n, i1, i2] = torch.sum(
torch.abs(x[n, i1, :] - y[n, i2, :])
)
else:
raise ValueError("No support for norm %d" % (norm))
x_dist = torch.min(dist, dim=2)[0] # (N, P1)
y_dist = torch.min(dist, dim=1)[0] # (N, P2)
@ -159,7 +166,7 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
return loss, lnorm
@staticmethod
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None):
def chamfer_distance_naive(x, y, x_normals=None, y_normals=None, norm: int = 2):
"""
Naive iterative implementation of nearest neighbor and chamfer distance.
Returns lists of the unreduced loss and loss_normals. This naive
@ -174,7 +181,14 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
for n in range(N):
for i1 in range(P1):
for i2 in range(P2):
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
if norm == 2:
dist[n, i1, i2] = torch.sum((x[n, i1, :] - y[n, i2, :]) ** 2)
elif norm == 1:
dist[n, i1, i2] = torch.sum(
torch.abs(x[n, i1, :] - y[n, i2, :])
)
else:
raise ValueError("No support for norm %d" % (norm))
loss = [
torch.min(dist, dim=2)[0], # (N, P1)
@ -208,30 +222,34 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, max_P1, max_P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(p1, p2)
for norm in [1, 2]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
weights = points_normals.weights
p11 = p1.detach().clone()
p22 = p2.detach().clone()
p11.requires_grad = True
p22.requires_grad = True
P1 = p1.shape[1]
P2 = p2.shape[1]
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(p11, p22, weights=weights)
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
p1, p2, norm=norm
)
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
# point_reduction = "mean".
loss, loss_norm = chamfer_distance(p11, p22, weights=weights, norm=norm)
pred_loss = pred_loss[0].sum(1) / P1 + pred_loss[1].sum(1) / P2
pred_loss *= weights
pred_loss = pred_loss.sum() / weights.sum()
# Check gradients
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
self.assertClose(loss, pred_loss)
self.assertTrue(loss_norm is None)
# Check gradients
self._check_gradients(loss, None, pred_loss, None, p1, p11, p2, p22)
def test_chamfer_vs_naive_pointcloud(self):
"""
@ -242,63 +260,67 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
"""
N, max_P1, max_P2 = 3, 70, 70
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
)
for norm in [1, 2]:
points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
weights = points_normals.weights
x_lengths = points_normals.p1_lengths
y_lengths = points_normals.p2_lengths
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2, device=device
)
# Chamfer with tensors as input for heterogeneous pointclouds.
cham_tensor, norm_tensor = chamfer_distance(
points_normals.p1,
points_normals.p2,
x_normals=points_normals.n1,
y_normals=points_normals.n2,
x_lengths=points_normals.p1_lengths,
y_lengths=points_normals.p2_lengths,
weights=weights,
norm=norm,
)
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_mean = (
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
)
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
# Chamfer with pointclouds as input.
pred_loss, pred_norm_loss = TestChamfer.chamfer_distance_naive_pointclouds(
points_normals.cloud1, points_normals.cloud2, norm=norm, device=device
)
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss[1] *= weights.view(N, 1)
pred_norm_loss_mean = (
pred_norm_loss[0].sum(1) / x_lengths + pred_norm_loss[1].sum(1) / y_lengths
)
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
# Mean reduction point loss.
pred_loss[0] *= weights.view(N, 1)
pred_loss[1] *= weights.view(N, 1)
pred_loss_mean = (
pred_loss[0].sum(1) / x_lengths + pred_loss[1].sum(1) / y_lengths
)
pred_loss_mean = pred_loss_mean.sum()
pred_loss_mean /= weights.sum()
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
# Mean reduction norm loss.
pred_norm_loss[0] *= weights.view(N, 1)
pred_norm_loss[1] *= weights.view(N, 1)
pred_norm_loss_mean = (
pred_norm_loss[0].sum(1) / x_lengths
+ pred_norm_loss[1].sum(1) / y_lengths
)
pred_norm_loss_mean = pred_norm_loss_mean.sum() / weights.sum()
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
self.assertClose(pred_loss_mean, cham_tensor)
self.assertClose(pred_norm_loss_mean, norm_tensor)
self._check_gradients(
cham_tensor,
norm_tensor,
pred_loss_mean,
pred_norm_loss_mean,
points_normals.cloud1.points_list(),
points_normals.p1,
points_normals.cloud2.points_list(),
points_normals.p2,
points_normals.cloud1.normals_list(),
points_normals.n1,
points_normals.cloud2.normals_list(),
points_normals.n2,
x_lengths,
y_lengths,
)
def test_chamfer_pointcloud_object_withnormals(self):
N = 5
@ -742,6 +764,19 @@ class TestChamfer(TestCaseMixin, unittest.TestCase):
with self.assertRaisesRegex(ValueError, "Pointclouds objects or torch.Tensor"):
chamfer_distance(x=[1, 1, 1], y=[1, 1, 1])
def test_invalid_norm(self):
N, P1, P2 = 7, 10, 18
device = get_random_cuda_device()
points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
p1 = points_normals.p1
p2 = points_normals.p2
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
chamfer_distance(p1, p2, norm=0)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
chamfer_distance(p1, p2, norm=3)
@staticmethod
def chamfer_with_init(
batch_size: int,

View File

@ -18,7 +18,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
torch.manual_seed(1)
@staticmethod
def _knn_points_naive(p1, p2, lengths1, lengths2, K: int) -> torch.Tensor:
def _knn_points_naive(
p1, p2, lengths1, lengths2, K: int, norm: int = 2
) -> torch.Tensor:
"""
Naive PyTorch implementation of K-Nearest Neighbors.
Returns always sorted results
@ -42,7 +44,12 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
pp1 = p1[n, :num1].view(num1, 1, D)
pp2 = p2[n, :num2].view(1, num2, D)
diff = pp1 - pp2
diff = (diff * diff).sum(2)
if norm == 2:
diff = (diff * diff).sum(2)
elif norm == 1:
diff = diff.abs().sum(2)
else:
raise ValueError("No support for norm %d" % (norm))
num2 = min(num2, K)
for i in range(num1):
dd = diff[i]
@ -59,9 +66,10 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
norms = [1, 2]
versions = [0, 1, 2, 3]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
for version in versions:
if version == 3 and K > 4:
continue
@ -73,9 +81,16 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
y_cuda.requires_grad_(True)
# forward
out1 = self._knn_points_naive(x, y, lengths1=None, lengths2=None, K=K)
out1 = self._knn_points_naive(
x, y, lengths1=None, lengths2=None, K=K, norm=norm
)
out2 = knn_points(
x_cuda, y_cuda, K=K, version=version, return_sorted=return_sorted
x_cuda,
y_cuda,
K=K,
norm=norm,
version=version,
return_sorted=return_sorted,
)
if K > 1 and not return_sorted:
# check out2 is not sorted
@ -121,8 +136,9 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
P1s = [8, 24]
P2s = [8, 16, 32]
Ks = [1, 3, 10]
factors = [Ns, Ds, P1s, P2s, Ks]
for N, D, P1, P2, K in product(*factors):
norms = [1, 2]
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
x = torch.rand((N, P1, D), device=device, requires_grad=True)
y = torch.rand((N, P2, D), device=device, requires_grad=True)
lengths1 = torch.randint(low=1, high=P1, size=(N,), device=device)
@ -135,9 +151,11 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
# forward
out1 = self._knn_points_naive(
x, y, lengths1=lengths1, lengths2=lengths2, K=K
x, y, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(
x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K, norm=norm
)
out2 = knn_points(x_csrc, y_csrc, lengths1=lengths1, lengths2=lengths2, K=K)
self.assertClose(out1[0], out2[0])
self.assertTrue(torch.all(out1[1] == out2[1]))
@ -198,6 +216,17 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
expected = all_expected[version]
self.assertEqual(actual, expected)
def test_invalid_norm(self):
device = get_random_cuda_device()
N, P1, P2, K, D = 4, 16, 12, 8, 3
x = torch.rand((N, P1, D), device=device)
y = torch.rand((N, P2, D), device=device)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=3)
with self.assertRaisesRegex(ValueError, "Support for 1 or 2 norm."):
knn_points(x, y, K=K, norm=0)
@staticmethod
def knn_square(N: int, P1: int, P2: int, D: int, K: int, device: str):
device = torch.device(device)