diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py index 58ac3daa..f37f5938 100644 --- a/pytorch3d/loss/chamfer.py +++ b/pytorch3d/loss/chamfer.py @@ -145,11 +145,11 @@ def chamfer_distance( cham_norm_x = x.new_zeros(()) cham_norm_y = x.new_zeros(()) - x_dists, x_idx = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) - y_dists, y_idx = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) + x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) + y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) - cham_x = x_dists[..., 0] # (N, P1) - cham_y = y_dists[..., 0] # (N, P2) + cham_x = x_nn.dists[..., 0] # (N, P1) + cham_y = y_nn.dists[..., 0] # (N, P2) if is_x_heterogeneous: cham_x[x_mask] = 0.0 @@ -162,8 +162,8 @@ def chamfer_distance( if return_normals: # Gather the normals using the indices and keep only value for k=0 - x_normals_near = knn_gather(y_normals, x_idx, y_lengths)[..., 0, :] - y_normals_near = knn_gather(x_normals, y_idx, x_lengths)[..., 0, :] + x_normals_near = knn_gather(y_normals, x_nn.idx, y_lengths)[..., 0, :] + y_normals_near = knn_gather(x_normals, y_nn.idx, x_lengths)[..., 0, :] cham_norm_x = 1 - torch.abs( F.cosine_similarity(x_normals, x_normals_near, dim=2, eps=1e-6)