add min_triangle_area argument to IsInsideTriangle

Summary:
1. changed IsInsideTriangle in geometry_utils to take in min_triangle_area parameter instead of hardcoded value
2. updated point_mesh_cpu.cpp and point_mesh_cuda.[h/cu] to adapt to changes in geometry_utils function signatures
3. updated point_mesh_distance.py and test_point_mesh_distance.py to modify _C. calls

Reviewed By: bottler

Differential Revision: D34459764

fbshipit-source-id: 0549e78713c6d68f03d85fb597a13dd88e09b686
This commit is contained in:
Winnie Lin 2022-02-25 12:43:04 -08:00 committed by Facebook GitHub Bot
parent 4d043fc9ac
commit 471b126818
7 changed files with 344 additions and 134 deletions

View File

@ -57,29 +57,33 @@ void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) {
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 2>& b) {
const std::array<vec3<T>, 2>& b,
const double /*min_triangle_area*/) {
using std::get;
return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b));
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 3>& b) {
const std::array<vec3<T>, 3>& b,
const double min_triangle_area) {
using std::get;
return PointTriangle3DistanceForward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b));
get<0>(a), get<0>(b), get<1>(b), get<2>(b), min_triangle_area);
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 2>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
const std::array<vec3<T>, 1>& b,
const double /*min_triangle_area*/) {
return HullDistance(b, a, 1);
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 3>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
const std::array<vec3<T>, 1>& b,
const double min_triangle_area) {
return HullDistance(b, a, min_triangle_area);
}
template <typename T>
@ -88,7 +92,8 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 2>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
at::TensorAccessor<T, 2>&& grad_b,
const double /*min_triangle_area*/) {
using std::get;
auto res =
PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist);
@ -102,10 +107,11 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 3>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
at::TensorAccessor<T, 2>&& grad_b,
const double min_triangle_area) {
using std::get;
auto res = PointTriangle3DistanceBackward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist);
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist, min_triangle_area);
IncrementPoint(std::move(grad_a), get<0>(res));
IncrementPoint(grad_b[0], get<1>(res));
IncrementPoint(grad_b[1], get<2>(res));
@ -117,9 +123,10 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
at::TensorAccessor<T, 1>&& grad_b,
const double min_triangle_area) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
b, a, grad_dist, std::move(grad_b), std::move(grad_a), min_triangle_area);
}
template <typename T>
void HullHullDistanceBackward(
@ -127,9 +134,10 @@ void HullHullDistanceBackward(
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
at::TensorAccessor<T, 1>&& grad_b,
const double /*min_triangle_area*/) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
b, a, grad_dist, std::move(grad_b), std::move(grad_a), 1);
}
template <int H>
@ -150,7 +158,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
const at::Tensor& as,
const at::Tensor& as_first_idx,
const at::Tensor& bs,
const at::Tensor& bs_first_idx) {
const at::Tensor& bs_first_idx,
const double min_triangle_area) {
const int64_t A_N = as.size(0);
const int64_t B_N = bs.size(0);
const int64_t BATCHES = as_first_idx.size(0);
@ -190,7 +199,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
size_t min_idx = 0;
auto a = ExtractHull<H1>(as_a[a_n]);
for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) {
float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n]));
float dist =
HullDistance(a, ExtractHull<H2>(bs_a[b_n]), min_triangle_area);
if (dist <= min_dist) {
min_dist = dist;
min_idx = b_n;
@ -208,7 +218,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
const at::Tensor& as,
const at::Tensor& bs,
const at::Tensor& idx_bs,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
const int64_t A_N = as.size(0);
TORCH_CHECK(idx_bs.size(0) == A_N);
@ -230,7 +241,12 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
auto a = ExtractHull<H1>(as_a[a_n]);
auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]);
HullHullDistanceBackward(
a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]);
a,
b,
grad_dists_a[a_n],
grad_as_a[a_n],
grad_bs_a[idx_bs_a[a_n]],
min_triangle_area);
}
return std::make_tuple(grad_as, grad_bs);
}
@ -238,7 +254,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
template <int H>
torch::Tensor PointHullArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& bs) {
const torch::Tensor& bs,
const double min_triangle_area) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);
@ -254,7 +271,7 @@ torch::Tensor PointHullArrayDistanceForwardCpu(
auto dest = dists_a[p];
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
dest[b_n] = HullDistance(point, b);
dest[b_n] = HullDistance(point, b, min_triangle_area);
}
}
return dists;
@ -264,7 +281,8 @@ template <int H>
std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& bs,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);
@ -287,7 +305,12 @@ std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
HullHullDistanceBackward(
point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]);
point,
b,
grad_dist[b_n],
std::move(grad_point),
grad_bs_a[b_n],
min_triangle_area);
}
}
return std::make_tuple(grad_points, grad_bs);
@ -299,63 +322,70 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
const torch::Tensor& tris_first_idx,
const double min_triangle_area) {
return HullHullDistanceForwardCpu<1, 3>(
points, points_first_idx, tris, tris_first_idx);
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
const torch::Tensor& grad_dists,
const double min_triangle_area) {
return HullHullDistanceBackwardCpu<1, 3>(
points, tris, idx_points, grad_dists);
points, tris, idx_points, grad_dists, min_triangle_area);
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
const torch::Tensor& tris_first_idx,
const double min_triangle_area) {
return HullHullDistanceForwardCpu<3, 1>(
tris, tris_first_idx, points, points_first_idx);
tris, tris_first_idx, points, points_first_idx, min_triangle_area);
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area) {
auto res = HullHullDistanceBackwardCpu<3, 1>(
tris, points, idx_tris, grad_dists, min_triangle_area);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}
torch::Tensor PointEdgeArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms) {
return PointHullArrayDistanceForwardCpu<2>(points, segms);
return PointHullArrayDistanceForwardCpu<2>(points, segms, 1);
}
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists);
const at::Tensor& grad_dists,
const double min_triangle_area) {
return PointHullArrayDistanceBackwardCpu<3>(
points, tris, grad_dists, min_triangle_area);
}
torch::Tensor PointFaceArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris) {
return PointHullArrayDistanceForwardCpu<3>(points, tris);
const torch::Tensor& tris,
const double min_triangle_area) {
return PointHullArrayDistanceForwardCpu<3>(points, tris, min_triangle_area);
}
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists);
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists, 1);
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
@ -365,7 +395,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
const torch::Tensor& segms_first_idx,
const int64_t /*max_points*/) {
return HullHullDistanceForwardCpu<1, 2>(
points, points_first_idx, segms, segms_first_idx);
points, points_first_idx, segms, segms_first_idx, 1);
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
@ -374,7 +404,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
return HullHullDistanceBackwardCpu<1, 2>(
points, segms, idx_points, grad_dists);
points, segms, idx_points, grad_dists, 1);
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
@ -384,7 +414,7 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
const torch::Tensor& segms_first_idx,
const int64_t /*max_segms*/) {
return HullHullDistanceForwardCpu<2, 1>(
segms, segms_first_idx, points, points_first_idx);
segms, segms_first_idx, points, points_first_idx, 1);
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
@ -392,7 +422,7 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists);
auto res = HullHullDistanceBackwardCpu<2, 1>(
segms, points, idx_segms, grad_dists, 1);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}

View File

@ -32,7 +32,8 @@ __global__ void DistanceForwardKernel(
const int64_t* __restrict__ targets_first_idx, // (B,)
const size_t batch_size, // B
float* __restrict__ dist_objects, // (O,)
int64_t* __restrict__ idx_objects) { // (O,)
int64_t* __restrict__ idx_objects, // (O,)
const double min_triangle_area) {
// This kernel is used interchangeably to compute bi-directional distances
// between points and triangles/lines. The direction of the distance computed,
// i.e. point to triangle/line or triangle/line to point, depends on the order
@ -93,7 +94,8 @@ __global__ void DistanceForwardKernel(
points_f3[point_idx],
face_f3[face_idx],
face_f3[face_idx + 1],
face_f3[face_idx + 2]);
face_f3[face_idx + 2],
min_triangle_area);
} else {
dist = PointLine3DistanceForward(
points_f3[point_idx], face_f3[face_idx], face_f3[face_idx + 1]);
@ -138,7 +140,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
const at::Tensor& targets,
const size_t targets_dim,
const at::Tensor& targets_first_idx,
const int64_t max_objects) {
const int64_t max_objects,
const double min_triangle_area) {
// Check inputs are on the same device
at::TensorArg objects_t{objects, "objects", 1},
objects_first_idx_t{objects_first_idx, "objects_first_idx", 2},
@ -202,7 +205,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceForwardCuda(
targets_first_idx.contiguous().data_ptr<int64_t>(),
batch_size,
dists.data_ptr<float>(),
idxs.data_ptr<int64_t>());
idxs.data_ptr<int64_t>(),
min_triangle_area);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(dists, idxs);
@ -217,7 +221,8 @@ __global__ void DistanceBackwardKernel(
const int64_t* __restrict__ idx_objects, // (O,)
const float* __restrict__ grad_dists, // (O,)
float* __restrict__ grad_points, // ((O or T) * 3)
float* __restrict__ grad_face) { // ((O or T) * max(oD, tD) * 3)
float* __restrict__ grad_face, // ((O or T) * max(oD, tD) * 3)
const double min_triangle_area) {
// This kernel is used interchangeably to compute bi-directional backward
// distances between points and triangles/lines. The direction of the distance
// computed, i.e. point to triangle/line or triangle/line to point, depends on
@ -247,7 +252,8 @@ __global__ void DistanceBackwardKernel(
face_f3[face_index],
face_f3[face_index + 1],
face_f3[face_index + 2],
grad_dists[o]);
grad_dists[o],
min_triangle_area);
grad_point = thrust::get<0>(grads);
grad_v0 = thrust::get<1>(grads);
grad_v1 = thrust::get<2>(grads);
@ -289,7 +295,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
const at::Tensor& targets,
const size_t targets_dim,
const at::Tensor& idx_objects,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
// Check inputs are on the same device
at::TensorArg objects_t{objects, "objects", 1},
targets_t{targets, "targets", 2},
@ -355,7 +362,8 @@ std::tuple<at::Tensor, at::Tensor> DistanceBackwardCuda(
idx_objects.contiguous().data_ptr<int64_t>(),
grad_dists.contiguous().data_ptr<float>(),
grad_points.data_ptr<float>(),
grad_tris.data_ptr<float>());
grad_tris.data_ptr<float>(),
min_triangle_area);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);
@ -370,17 +378,27 @@ std::tuple<at::Tensor, at::Tensor> PointFaceDistanceForwardCuda(
const at::Tensor& points_first_idx,
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_points) {
const int64_t max_points,
const double min_triangle_area) {
return DistanceForwardCuda(
points, 1, points_first_idx, tris, 3, tris_first_idx, max_points);
points,
1,
points_first_idx,
tris,
3,
tris_first_idx,
max_points,
min_triangle_area);
}
std::tuple<at::Tensor, at::Tensor> PointFaceDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
return DistanceBackwardCuda(points, 1, tris, 3, idx_points, grad_dists);
const at::Tensor& grad_dists,
const double min_triangle_area) {
return DistanceBackwardCuda(
points, 1, tris, 3, idx_points, grad_dists, min_triangle_area);
}
// ****************************************************************************
@ -392,17 +410,27 @@ std::tuple<at::Tensor, at::Tensor> FacePointDistanceForwardCuda(
const at::Tensor& points_first_idx,
const at::Tensor& tris,
const at::Tensor& tris_first_idx,
const int64_t max_tris) {
const int64_t max_tris,
const double min_triangle_area) {
return DistanceForwardCuda(
tris, 3, tris_first_idx, points, 1, points_first_idx, max_tris);
tris,
3,
tris_first_idx,
points,
1,
points_first_idx,
max_tris,
min_triangle_area);
}
std::tuple<at::Tensor, at::Tensor> FacePointDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& idx_tris,
const at::Tensor& grad_dists) {
return DistanceBackwardCuda(tris, 3, points, 1, idx_tris, grad_dists);
const at::Tensor& grad_dists,
const double min_triangle_area) {
return DistanceBackwardCuda(
tris, 3, points, 1, idx_tris, grad_dists, min_triangle_area);
}
// ****************************************************************************
@ -416,7 +444,14 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceForwardCuda(
const at::Tensor& segms_first_idx,
const int64_t max_points) {
return DistanceForwardCuda(
points, 1, points_first_idx, segms, 2, segms_first_idx, max_points);
points,
1,
points_first_idx,
segms,
2,
segms_first_idx,
max_points,
1); // todo: unused parameter handling for min_triangle_area
}
std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
@ -424,7 +459,7 @@ std::tuple<at::Tensor, at::Tensor> PointEdgeDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_points,
const at::Tensor& grad_dists) {
return DistanceBackwardCuda(points, 1, segms, 2, idx_points, grad_dists);
return DistanceBackwardCuda(points, 1, segms, 2, idx_points, grad_dists, 1);
}
// ****************************************************************************
@ -438,7 +473,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceForwardCuda(
const at::Tensor& segms_first_idx,
const int64_t max_segms) {
return DistanceForwardCuda(
segms, 2, segms_first_idx, points, 1, points_first_idx, max_segms);
segms, 2, segms_first_idx, points, 1, points_first_idx, max_segms, 1);
}
std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
@ -446,7 +481,7 @@ std::tuple<at::Tensor, at::Tensor> EdgePointDistanceBackwardCuda(
const at::Tensor& segms,
const at::Tensor& idx_segms,
const at::Tensor& grad_dists) {
return DistanceBackwardCuda(segms, 2, points, 1, idx_segms, grad_dists);
return DistanceBackwardCuda(segms, 2, points, 1, idx_segms, grad_dists, 1);
}
// ****************************************************************************
@ -459,7 +494,8 @@ __global__ void PointFaceArrayForwardKernel(
const float* __restrict__ tris, // (T, 3, 3)
float* __restrict__ dists, // (P, T)
const size_t P,
const size_t T) {
const size_t T,
const double min_triangle_area) {
const float3* points_f3 = (float3*)points;
const float3* tris_f3 = (float3*)tris;
@ -475,14 +511,16 @@ __global__ void PointFaceArrayForwardKernel(
const float3 v2 = tris_f3[t * 3 + 2];
const float3 point = points_f3[p];
float dist = PointTriangle3DistanceForward(point, v0, v1, v2);
float dist =
PointTriangle3DistanceForward(point, v0, v1, v2, min_triangle_area);
dists[p * T + t] = dist;
}
}
at::Tensor PointFaceArrayDistanceForwardCuda(
const at::Tensor& points,
const at::Tensor& tris) {
const at::Tensor& tris,
const double min_triangle_area) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2};
at::CheckedFrom c = "PointFaceArrayDistanceForwardCuda";
@ -516,7 +554,8 @@ at::Tensor PointFaceArrayDistanceForwardCuda(
tris.contiguous().data_ptr<float>(),
dists.data_ptr<float>(),
P,
T);
T,
min_triangle_area);
AT_CUDA_CHECK(cudaGetLastError());
return dists;
@ -529,7 +568,8 @@ __global__ void PointFaceArrayBackwardKernel(
float* __restrict__ grad_points, // (P, 3)
float* __restrict__ grad_tris, // (T, 3, 3)
const size_t P,
const size_t T) {
const size_t T,
const double min_triangle_area) {
const float3* points_f3 = (float3*)points;
const float3* tris_f3 = (float3*)tris;
@ -547,8 +587,8 @@ __global__ void PointFaceArrayBackwardKernel(
const float3 point = points_f3[p];
const float grad_dist = grad_dists[p * T + t];
const auto grad =
PointTriangle3DistanceBackward(point, v0, v1, v2, grad_dist);
const auto grad = PointTriangle3DistanceBackward(
point, v0, v1, v2, grad_dist, min_triangle_area);
const float3 grad_point = thrust::get<0>(grad);
const float3 grad_v0 = thrust::get<1>(grad);
@ -576,7 +616,8 @@ __global__ void PointFaceArrayBackwardKernel(
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
const at::Tensor& grad_dists,
const double min_triangle_area) {
// Check inputs are on the same device
at::TensorArg points_t{points, "points", 1}, tris_t{tris, "tris", 2},
grad_dists_t{grad_dists, "grad_dists", 3};
@ -615,7 +656,8 @@ std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCuda(
grad_points.data_ptr<float>(),
grad_tris.data_ptr<float>(),
P,
T);
T,
min_triangle_area);
AT_CUDA_CHECK(cudaGetLastError());
return std::make_tuple(grad_points, grad_tris);

View File

@ -31,6 +31,8 @@
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
// the maximum number of points in the batch and is used to set
// the block dimensions in the CUDA implementation.
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// dists: FloatTensor of shape (P,), where dists[p] is the minimum
@ -51,21 +53,24 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_points);
const int64_t max_points,
const double min_triangle_area);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx);
const torch::Tensor& tris_first_idx,
const double min_triangle_area);
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_points) {
const int64_t max_points,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
@ -73,13 +78,18 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
CHECK_CUDA(tris);
CHECK_CUDA(tris_first_idx);
return PointFaceDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_points);
points,
points_first_idx,
tris,
tris_first_idx,
max_points,
min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return PointFaceDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx);
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}
// Backward pass for PointFaceDistance.
@ -91,6 +101,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
// of the closest face in the example in the batch.
// This is computed by the forward pass
// grad_dists: FloatTensor of shape (P,)
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
@ -103,31 +115,36 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
const torch::Tensor& grad_dists,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(tris);
CHECK_CUDA(idx_points);
CHECK_CUDA(grad_dists);
return PointFaceDistanceBackwardCuda(points, tris, idx_points, grad_dists);
return PointFaceDistanceBackwardCuda(
points, tris, idx_points, grad_dists, min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return PointFaceDistanceBackwardCpu(points, tris, idx_points, grad_dists);
return PointFaceDistanceBackwardCpu(
points, tris, idx_points, grad_dists, min_triangle_area);
}
// ****************************************************************************
@ -148,6 +165,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
// max_tris: Scalar equal to max(T_i) for i in [0, N - 1] containing
// the maximum number of faces in the batch and is used to set
// the block dimensions in the CUDA implementation.
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// dists: FloatTensor of shape (T,), where dists[t] is the minimum squared
@ -167,21 +186,24 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_tris);
const int64_t max_tris,
const double min_triangle_area);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx);
const torch::Tensor& tris_first_idx,
const double min_triangle_area);
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_tris) {
const int64_t max_tris,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
@ -189,13 +211,18 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
CHECK_CUDA(tris);
CHECK_CUDA(tris_first_idx);
return FacePointDistanceForwardCuda(
points, points_first_idx, tris, tris_first_idx, max_tris);
points,
points_first_idx,
tris,
tris_first_idx,
max_tris,
min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return FacePointDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx);
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}
// Backward pass for FacePointDistance.
@ -207,6 +234,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
// of the closest point in the example in the batch.
// This is computed by the forward pass
// grad_dists: FloatTensor of shape (T,)
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
@ -219,32 +248,37 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
const torch::Tensor& grad_dists,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(tris);
CHECK_CUDA(idx_tris);
CHECK_CUDA(grad_dists);
return FacePointDistanceBackwardCuda(points, tris, idx_tris, grad_dists);
return FacePointDistanceBackwardCuda(
points, tris, idx_tris, grad_dists, min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists);
return FacePointDistanceBackwardCpu(
points, tris, idx_tris, grad_dists, min_triangle_area);
}
// ****************************************************************************
@ -494,6 +528,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// dists: FloatTensor of shape (P, T), where dists[p, t] is the squared
@ -509,26 +545,29 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
torch::Tensor PointFaceArrayDistanceForwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris);
const torch::Tensor& tris,
const double min_triangle_area);
#endif
torch::Tensor PointFaceArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris);
const torch::Tensor& tris,
const double min_triangle_area);
torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& points,
const torch::Tensor& tris) {
const torch::Tensor& tris,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(tris);
return PointFaceArrayDistanceForwardCuda(points, tris);
return PointFaceArrayDistanceForwardCuda(points, tris, min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return PointFaceArrayDistanceForwardCpu(points, tris);
return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
}
// Backward pass for PointFaceArrayDistance.
@ -537,6 +576,8 @@ torch::Tensor PointFaceArrayDistanceForward(
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3)
// grad_dists: FloatTensor of shape (P, T)
// min_triangle_area: triangles less than this size are considered
// points/lines.
//
// Returns:
// grad_points: FloatTensor of shape (P, 3)
@ -547,28 +588,33 @@ torch::Tensor PointFaceArrayDistanceForward(
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists);
const torch::Tensor& grad_dists,
const double min_triangle_area);
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists) {
const torch::Tensor& grad_dists,
const double min_triangle_area) {
if (points.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(tris);
CHECK_CUDA(grad_dists);
return PointFaceArrayDistanceBackwardCuda(points, tris, grad_dists);
return PointFaceArrayDistanceBackwardCuda(
points, tris, grad_dists, min_triangle_area);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists);
return PointFaceArrayDistanceBackwardCpu(
points, tris, grad_dists, min_triangle_area);
}
// ****************************************************************************

View File

@ -540,6 +540,8 @@ __device__ inline float3 BarycentricCoords3Forward(
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
// min_triangle_area: triangles less than this size are considered
// points/lines, IsInsideTriangle returns False
//
// Returns:
// inside: bool indicating wether p is inside triangle
@ -548,9 +550,10 @@ __device__ inline bool IsInsideTriangle(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2) {
const float3& v2,
const double min_triangle_area) {
bool inside;
if (AreaOfTriangle(v0, v1, v2) < 5e-3) {
if (AreaOfTriangle(v0, v1, v2) < min_triangle_area) {
inside = 0;
} else {
float3 bary = BarycentricCoords3Forward(p, v0, v1, v2);
@ -660,6 +663,8 @@ PointLine3DistanceBackward(
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
// min_triangle_area: triangles less than this size are considered
// points/lines, IsInsideTriangle returns False
//
// Returns:
// dist: Float of the squared distance
@ -669,7 +674,8 @@ __device__ inline float PointTriangle3DistanceForward(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2) {
const float3& v2,
const double min_triangle_area) {
float3 normal = cross(v2 - v0, v1 - v0);
const float norm_normal = norm(normal);
normal = normalize(normal);
@ -679,7 +685,7 @@ __device__ inline float PointTriangle3DistanceForward(
const float t = dot(v0 - p, normal);
const float3 p0 = p + t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
float dist = 0.0f;
if ((is_inside) && (norm_normal > kEpsilon)) {
@ -705,6 +711,8 @@ __device__ inline float PointTriangle3DistanceForward(
// p: xyz coordinates of a point
// v0, v1, v2: xyz coordinates of the triangle vertices
// grad_dist: Float of the gradient wrt dist
// min_triangle_area: triangles less than this size are considered
// points/lines, IsInsideTriangle returns False
//
// Returns:
// tuple of gradients for the point and triangle:
@ -717,7 +725,8 @@ PointTriangle3DistanceBackward(
const float3& v0,
const float3& v1,
const float3& v2,
const float& grad_dist) {
const float& grad_dist,
const double min_triangle_area) {
const float3 v2v0 = v2 - v0;
const float3 v1v0 = v1 - v0;
const float3 v0p = v0 - p;
@ -731,7 +740,7 @@ PointTriangle3DistanceBackward(
const float3 p0 = p + t * normal;
const float3 diff = t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);

View File

@ -640,6 +640,8 @@ vec3<T> BarycentricCoords3Forward(
// Args:
// p: vec3 coordinates of a point
// v0, v1, v2: vec3 coordinates of the triangle vertices
// min_triangle_area: triangles less than this size are considered
// points/lines, IsInsideTriangle returns False
//
// Returns:
// inside: bool indicating wether p is inside triangle
@ -649,9 +651,10 @@ static bool IsInsideTriangle(
const vec3<T>& p,
const vec3<T>& v0,
const vec3<T>& v1,
const vec3<T>& v2) {
const vec3<T>& v2,
const double min_triangle_area) {
bool inside;
if (AreaOfTriangle(v0, v1, v2) < 5e-3) {
if (AreaOfTriangle(v0, v1, v2) < min_triangle_area) {
inside = 0;
} else {
vec3<T> bary = BarycentricCoords3Forward(p, v0, v1, v2);
@ -668,7 +671,8 @@ T PointTriangle3DistanceForward(
const vec3<T>& p,
const vec3<T>& v0,
const vec3<T>& v1,
const vec3<T>& v2) {
const vec3<T>& v2,
const double min_triangle_area) {
vec3<T> normal = cross(v2 - v0, v1 - v0);
const T norm_normal = norm(normal);
normal = normal / (norm_normal + vEpsilon);
@ -678,7 +682,7 @@ T PointTriangle3DistanceForward(
const T t = dot(v0 - p, normal);
const vec3<T> p0 = p + t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
T dist = 0.0f;
if ((is_inside) && (norm_normal > kEpsilon)) {
@ -737,6 +741,8 @@ vec3<T> normalize_backward(const vec3<T>& a, const vec3<T>& grad_normz) {
// p: xyz coordinates of a point
// v0, v1, v2: xyz coordinates of the triangle vertices
// grad_dist: Float of the gradient wrt dist
// min_triangle_area: triangles less than this size are considered
// points/lines, IsInsideTriangle returns False
//
// Returns:
// tuple of gradients for the point and triangle:
@ -750,7 +756,8 @@ PointTriangle3DistanceBackward(
const vec3<T>& v0,
const vec3<T>& v1,
const vec3<T>& v2,
const T& grad_dist) {
const T& grad_dist,
const double min_triangle_area) {
const vec3<T> v2v0 = v2 - v0;
const vec3<T> v1v0 = v1 - v0;
const vec3<T> v0p = v0 - p;
@ -764,7 +771,7 @@ PointTriangle3DistanceBackward(
const vec3<T> p0 = p + t * normal;
const vec3<T> diff = t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
bool is_inside = IsInsideTriangle(p0, v0, v1, v2, min_triangle_area);
vec3<T> grad_p(0.0f, 0.0f, 0.0f);
vec3<T> grad_v0(0.0f, 0.0f, 0.0f);

View File

@ -19,6 +19,8 @@ The exact mathematical formulations and implementations of these
distances can be found in `csrc/utils/geometry_utils.cuh`.
"""
_DEFAULT_MIN_TRIANGLE_AREA: float = 5e-3
# PointFaceDistance
class _PointFaceDistance(Function):
@ -27,7 +29,15 @@ class _PointFaceDistance(Function):
"""
@staticmethod
def forward(ctx, points, points_first_idx, tris, tris_first_idx, max_points):
def forward(
ctx,
points,
points_first_idx,
tris,
tris_first_idx,
max_points,
min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
):
"""
Args:
ctx: Context object used to calculate gradients.
@ -39,6 +49,8 @@ class _PointFaceDistance(Function):
tris_first_idx: LongTensor of shape `(N,)` indicating the first face
index in each example in the batch
max_points: Scalar equal to maximum number of points in the batch
min_triangle_area: (float, defaulted) Triangles of area less than this
will be treated as points/lines.
Returns:
dists: FloatTensor of shape `(P,)`, where `dists[p]` is the squared
euclidean distance of `p`-th point to the closest triangular face
@ -53,9 +65,15 @@ class _PointFaceDistance(Function):
"""
dists, idxs = _C.point_face_dist_forward(
points, points_first_idx, tris, tris_first_idx, max_points
points,
points_first_idx,
tris,
tris_first_idx,
max_points,
min_triangle_area,
)
ctx.save_for_backward(points, tris, idxs)
ctx.min_triangle_area = min_triangle_area
return dists
@staticmethod
@ -63,10 +81,11 @@ class _PointFaceDistance(Function):
def backward(ctx, grad_dists):
grad_dists = grad_dists.contiguous()
points, tris, idxs = ctx.saved_tensors
min_triangle_area = ctx.min_triangle_area
grad_points, grad_tris = _C.point_face_dist_backward(
points, tris, idxs, grad_dists
points, tris, idxs, grad_dists, min_triangle_area
)
return grad_points, None, grad_tris, None, None
return grad_points, None, grad_tris, None, None, None
# pyre-fixme[16]: `_PointFaceDistance` has no attribute `apply`.
@ -80,7 +99,15 @@ class _FacePointDistance(Function):
"""
@staticmethod
def forward(ctx, points, points_first_idx, tris, tris_first_idx, max_tris):
def forward(
ctx,
points,
points_first_idx,
tris,
tris_first_idx,
max_tris,
min_triangle_area=_DEFAULT_MIN_TRIANGLE_AREA,
):
"""
Args:
ctx: Context object used to calculate gradients.
@ -92,6 +119,8 @@ class _FacePointDistance(Function):
tris_first_idx: LongTensor of shape `(N,)` indicating the first face
index in each example in the batch
max_tris: Scalar equal to maximum number of faces in the batch
min_triangle_area: (float, defaulted) Triangles of area less than this
will be treated as points/lines.
Returns:
dists: FloatTensor of shape `(T,)`, where `dists[t]` is the squared
euclidean distance of `t`-th triangular face to the closest point in the
@ -104,9 +133,10 @@ class _FacePointDistance(Function):
face `(v0, v1, v2)`.
"""
dists, idxs = _C.face_point_dist_forward(
points, points_first_idx, tris, tris_first_idx, max_tris
points, points_first_idx, tris, tris_first_idx, max_tris, min_triangle_area
)
ctx.save_for_backward(points, tris, idxs)
ctx.min_triangle_area = min_triangle_area
return dists
@staticmethod
@ -114,10 +144,11 @@ class _FacePointDistance(Function):
def backward(ctx, grad_dists):
grad_dists = grad_dists.contiguous()
points, tris, idxs = ctx.saved_tensors
min_triangle_area = ctx.min_triangle_area
grad_points, grad_tris = _C.face_point_dist_backward(
points, tris, idxs, grad_dists
points, tris, idxs, grad_dists, min_triangle_area
)
return grad_points, None, grad_tris, None, None
return grad_points, None, grad_tris, None, None, None
# pyre-fixme[16]: `_FacePointDistance` has no attribute `apply`.
@ -293,7 +324,11 @@ def point_mesh_edge_distance(meshes: Meshes, pcls: Pointclouds):
return point_dist + edge_dist
def point_mesh_face_distance(meshes: Meshes, pcls: Pointclouds):
def point_mesh_face_distance(
meshes: Meshes,
pcls: Pointclouds,
min_triangle_area: float = _DEFAULT_MIN_TRIANGLE_AREA,
):
"""
Computes the distance between a pointcloud and a mesh within a batch.
Given a pair `(mesh, pcl)` in the batch, we define the distance to be the
@ -310,6 +345,8 @@ def point_mesh_face_distance(meshes: Meshes, pcls: Pointclouds):
Args:
meshes: A Meshes data structure containing N meshes
pcls: A Pointclouds data structure containing N pointclouds
min_triangle_area: (float, defaulted) Triangles of area less than this
will be treated as points/lines.
Returns:
loss: The `point_face(mesh, pcl) + face_point(mesh, pcl)` distance
@ -334,7 +371,7 @@ def point_mesh_face_distance(meshes: Meshes, pcls: Pointclouds):
# point to face distance: shape (P,)
point_to_face = point_face_distance(
points, points_first_idx, tris, tris_first_idx, max_points
points, points_first_idx, tris, tris_first_idx, max_points, min_triangle_area
)
# weight each example by the inverse of number of points in the example
@ -347,7 +384,7 @@ def point_mesh_face_distance(meshes: Meshes, pcls: Pointclouds):
# face to point distance: shape (T,)
face_to_point = face_point_distance(
points, points_first_idx, tris, tris_first_idx, max_tris
points, points_first_idx, tris, tris_first_idx, max_tris, min_triangle_area
)
# weight each example by the inverse of number of faces in the example

View File

@ -23,6 +23,10 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
def eps():
return 1e-8
@staticmethod
def min_triangle_area():
return 5e-3
@staticmethod
def init_meshes_clouds(
batch_size: int = 10,
@ -563,8 +567,12 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
grad_tris_naive = tris.grad.cpu()
# Cuda Forward Implementation
dists_cuda = _C.point_face_array_dist_forward(points, tris)
dists_cpu = _C.point_face_array_dist_forward(points_cpu, tris_cpu)
dists_cuda = _C.point_face_array_dist_forward(
points, tris, TestPointMeshDistance.min_triangle_area()
)
dists_cpu = _C.point_face_array_dist_forward(
points_cpu, tris_cpu, TestPointMeshDistance.min_triangle_area()
)
# Compare
self.assertClose(dists_naive.cpu(), dists_cuda.cpu())
@ -572,10 +580,13 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
# CUDA Backward Implementation
grad_points_cuda, grad_tris_cuda = _C.point_face_array_dist_backward(
points, tris, grad_dists
points, tris, grad_dists, TestPointMeshDistance.min_triangle_area()
)
grad_points_cpu, grad_tris_cpu = _C.point_face_array_dist_backward(
points_cpu, tris_cpu, grad_dists.cpu()
points_cpu,
tris_cpu,
grad_dists.cpu(),
TestPointMeshDistance.min_triangle_area(),
)
# Compare
@ -615,12 +626,21 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
# Cuda Implementation: forward
dists_cuda, idx_cuda = _C.point_face_dist_forward(
points_packed, points_first_idx, faces_packed, faces_first_idx, max_p
points_packed,
points_first_idx,
faces_packed,
faces_first_idx,
max_p,
TestPointMeshDistance.min_triangle_area(),
)
# Cuda Implementation: backward
grad_points_cuda, grad_faces_cuda = _C.point_face_dist_backward(
points_packed, faces_packed, idx_cuda, grad_dists
points_packed,
faces_packed,
idx_cuda,
grad_dists,
TestPointMeshDistance.min_triangle_area(),
)
# Cpu Implementation: forward
@ -630,12 +650,17 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
faces_packed.cpu(),
faces_first_idx.cpu(),
max_p,
TestPointMeshDistance.min_triangle_area(),
)
# Cpu Implementation: backward
# Note that using idx_cpu doesn't pass - there seems to be a problem with tied results.
grad_points_cpu, grad_faces_cpu = _C.point_face_dist_backward(
points_packed.cpu(), faces_packed.cpu(), idx_cuda.cpu(), grad_dists.cpu()
points_packed.cpu(),
faces_packed.cpu(),
idx_cuda.cpu(),
grad_dists.cpu(),
TestPointMeshDistance.min_triangle_area(),
)
# Naive Implementation: forward
@ -716,12 +741,21 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
# Cuda Implementation: forward
dists_cuda, idx_cuda = _C.face_point_dist_forward(
points_packed, points_first_idx, faces_packed, faces_first_idx, max_f
points_packed,
points_first_idx,
faces_packed,
faces_first_idx,
max_f,
TestPointMeshDistance.min_triangle_area(),
)
# Cuda Implementation: backward
grad_points_cuda, grad_faces_cuda = _C.face_point_dist_backward(
points_packed, faces_packed, idx_cuda, grad_dists
points_packed,
faces_packed,
idx_cuda,
grad_dists,
TestPointMeshDistance.min_triangle_area(),
)
# Cpu Implementation: forward
@ -731,11 +765,16 @@ class TestPointMeshDistance(TestCaseMixin, unittest.TestCase):
faces_packed.cpu(),
faces_first_idx.cpu(),
max_f,
TestPointMeshDistance.min_triangle_area(),
)
# Cpu Implementation: backward
grad_points_cpu, grad_faces_cpu = _C.face_point_dist_backward(
points_packed.cpu(), faces_packed.cpu(), idx_cpu, grad_dists.cpu()
points_packed.cpu(),
faces_packed.cpu(),
idx_cpu,
grad_dists.cpu(),
TestPointMeshDistance.min_triangle_area(),
)
# Naive Implementation: forward