diff --git a/pytorch3d/ops/knn.py b/pytorch3d/ops/knn.py index aff9fc6f..97e1191e 100644 --- a/pytorch3d/ops/knn.py +++ b/pytorch3d/ops/knn.py @@ -115,18 +115,18 @@ def knn_points( return_nn: If set to True returns the K nearest neighors in p2 for each point in p1. Returns: - p1_idx: LongTensor of shape (N, P1, K) giving the indices of the + dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + idx: LongTensor of shape (N, P1, K) giving the indices of the K nearest neighbors from points in p1 to points in p2. Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. - p1_dists: Tensor of shape (N, P1, K) giving the squared distances to - the nearest neighbors. This is padded with zeros both where a cloud in p2 - has fewer than K points and where a cloud in p1 has fewer than P1 points. - - p2_nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for + nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor for `p1[n, i]`. Returned if `return_nn` is True. The nearest neighbors are collected using `knn_gather` diff --git a/pytorch3d/ops/points_alignment.py b/pytorch3d/ops/points_alignment.py index 7ac3f182..eedf8d51 100644 --- a/pytorch3d/ops/points_alignment.py +++ b/pytorch3d/ops/points_alignment.py @@ -158,7 +158,7 @@ def iterative_closest_point( for iteration in range(max_iterations): Xt_nn_points = knn_points( Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True - )[2][:, :, 0, :] + ).knn[:, :, 0, :] # get the alignment of the nearest neighbors from Yt with Xt_init R, T, s = corresponding_points_alignment( diff --git a/pytorch3d/ops/utils.py b/pytorch3d/ops/utils.py index 41e00ace..ea4980be 100644 --- a/pytorch3d/ops/utils.py +++ b/pytorch3d/ops/utils.py @@ -126,14 +126,14 @@ def get_point_covariances( of shape `(minibatch, num_points, neighborhood_size, dim)`. """ # get K nearest neighbor idx for each point in the point cloud - _, _, k_nearest_neighbors = knn_points( + k_nearest_neighbors = knn_points( points_padded, points_padded, lengths1=num_points_per_cloud, lengths2=num_points_per_cloud, K=neighborhood_size, return_nn=True, - ) + ).knn # obtain the mean of the neighborhood pt_mean = k_nearest_neighbors.mean(2, keepdim=True) # compute the diff of the neighborhood and the mean of the neighborhood