pytorch3d/pytorch3d/csrc/ball_query/ball_query_cpu.cpp
Nicholas Ormrod 20bd8b33f6 facebook-unused-include-check in fbcode/vision
Summary:
Remove headers flagged by facebook-unused-include-check over fbcode.vision.

+ format and autodeps

This is a codemod. It was automatically generated and will be landed once it is approved and tests are passing in sandcastle.
You have been added as a reviewer by Sentinel or Butterfly.

Autodiff project: uiv
Autodiff partition: fbcode.vision
Autodiff bookmark: ad.uiv.fbcode.vision

Reviewed By: dtolnay

Differential Revision: D70403619

fbshipit-source-id: d109c15774eeb3d809875f75fa2a26ed20d7f9a6
2025-02-28 18:08:12 -08:00

55 lines
1.6 KiB
C++

/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/extension.h>
#include <tuple>
// Brute-force CPU ball query over padded batches of points.
//
// For every valid query point p1[n][i] (i < lengths1[n]) the candidates
// p2[n][0 .. lengths2[n]) are scanned in index order, and the first K of them
// whose squared Euclidean distance to the query is strictly less than
// radius * radius are recorded. Neighbors are therefore the first K hits in
// p2 order, not the K closest ones.
//
// Args:
//   p1: float tensor indexed as (N, P1, D) holding the query points.
//   p2: float tensor indexed as (N, P2, D) holding the candidate points.
//   lengths1: int64 1-D tensor; lengths1[n] is the number of query points of
//       p1[n] that are processed.
//   lengths2: int64 1-D tensor; lengths2[n] is the number of candidate points
//       of p2[n] that are scanned.
//   K: maximum number of neighbors stored per query point.
//   radius: ball radius; compared in squared form against squared distances.
//
// Returns:
//   (idxs, dists) where
//   idxs: int64 tensor of shape (N, P1, K) with indices into p2[n]; slots that
//       were not filled keep the value -1.
//   dists: tensor of shape (N, P1, K) created with p1's options, holding the
//       squared distances of the recorded neighbors; unfilled slots keep 0.
//
// Throws (via TORCH_CHECK) when p2 or the lengths tensors are too small for
// p1's shape, or when a lengths entry exceeds the padded point dimension. The
// loops below would otherwise read or write outside the tensors.
std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
    const at::Tensor& p1,
    const at::Tensor& p2,
    const at::Tensor& lengths1,
    const at::Tensor& lengths2,
    int K,
    float radius) {
  // Keep sizes as int64_t (the type size() returns) to avoid narrowing.
  const int64_t N = p1.size(0);
  const int64_t P1 = p1.size(1);
  const int64_t D = p1.size(2);
  const int64_t P2 = p2.size(1);

  // Shape guards: every index used below must stay inside its tensor.
  TORCH_CHECK(
      p2.size(0) >= N,
      "BallQueryCpu: p2 has ",
      p2.size(0),
      " batch elements but p1 has ",
      N);
  TORCH_CHECK(
      p2.size(2) >= D,
      "BallQueryCpu: p2 has point dimension ",
      p2.size(2),
      " but p1 has ",
      D);
  TORCH_CHECK(
      lengths1.size(0) >= N && lengths2.size(0) >= N,
      "BallQueryCpu: lengths1 and lengths2 must each hold at least ",
      N,
      " entries");

  auto long_opts = lengths1.options().dtype(torch::kInt64);
  torch::Tensor idxs = torch::full({N, P1, K}, -1, long_opts);
  torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());

  // Compare squared distances so no square root is needed per pair.
  const float radius2 = radius * radius;

  auto p1_a = p1.accessor<float, 3>();
  auto p2_a = p2.accessor<float, 3>();
  auto lengths1_a = lengths1.accessor<int64_t, 1>();
  auto lengths2_a = lengths2.accessor<int64_t, 1>();
  auto idxs_a = idxs.accessor<int64_t, 3>();
  auto dists_a = dists.accessor<float, 3>();

  for (int64_t n = 0; n < N; ++n) {
    const int64_t length1 = lengths1_a[n];
    const int64_t length2 = lengths2_a[n];

    // A length larger than the padded dimension would index past the end of
    // p1 / p2 (and of the outputs, for length1).
    TORCH_CHECK(
        length1 <= P1,
        "BallQueryCpu: lengths1[",
        n,
        "] = ",
        length1,
        " exceeds p1.size(1) = ",
        P1);
    TORCH_CHECK(
        length2 <= P2,
        "BallQueryCpu: lengths2[",
        n,
        "] = ",
        length2,
        " exceeds p2.size(1) = ",
        P2);

    for (int64_t i = 0; i < length1; ++i) {
      // Stop scanning as soon as K neighbors have been stored for this query.
      for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
        float dist2 = 0;
        for (int64_t d = 0; d < D; ++d) {
          const float diff = p1_a[n][i][d] - p2_a[n][j][d];
          dist2 += diff * diff;
        }
        if (dist2 < radius2) {
          dists_a[n][i][count] = dist2;
          idxs_a[n][i][count] = j;
          ++count;
        }
      }
    }
  }
  return std::make_tuple(idxs, dists);
}