CPU implementation for point_mesh functions

Summary:
point_mesh functions were missing CPU implementations.
The indices returned are not always matching, possibly due to numerical instability.

Reviewed By: gkioxari

Differential Revision: D21594264

fbshipit-source-id: 3016930e2a9a0f3cd8b3ac4c94a92c9411c0989d
This commit is contained in:
Jeremy Reizenstein
2020-06-15 10:08:15 -07:00
committed by Facebook GitHub Bot
parent 7f1e63aed1
commit 74659aef26
6 changed files with 878 additions and 31 deletions

View File

@@ -0,0 +1,398 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <torch/extension.h>
#include <array>
#include <limits>
#include "utils/geometry_utils.h"
#include "utils/vec3.h"
// - We start with implementations of simple operations on points, edges and
// faces. The hull of H points is a point if H=1, an edge if H=2, a face if H=3.
template <typename T>
vec3<T> ExtractPoint(const at::TensorAccessor<T, 1>& t) {
return vec3(t[0], t[1], t[2]);
}
template <class Accessor>
struct ExtractHullHelper {
template <int H>
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H>
get(const Accessor& t);
template <>
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 1>
get<1>(const Accessor& t) {
return {ExtractPoint(t)};
}
template <>
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 2>
get<2>(const Accessor& t) {
return {ExtractPoint(t[0]), ExtractPoint(t[1])};
}
template <>
static std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, 3>
get<3>(const Accessor& t) {
return {ExtractPoint(t[0]), ExtractPoint(t[1]), ExtractPoint(t[2])};
}
};
template <int H, typename Accessor>
std::array<vec3<std::remove_pointer_t<typename Accessor::PtrType>>, H>
ExtractHull(const Accessor& t) {
return ExtractHullHelper<Accessor>::template get<H>(t);
}
template <typename T>
void IncrementPoint(at::TensorAccessor<T, 1>&& t, const vec3<T>& point) {
t[0] += point.x;
t[1] += point.y;
t[2] += point.z;
}
// distance between the convex hull of A points and B points
// this could be done in c++17 with tuple_cat and invoke
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 2>& b) {
using std::get;
return PointLine3DistanceForward(get<0>(a), get<0>(b), get<1>(b));
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 3>& b) {
using std::get;
return PointTriangle3DistanceForward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b));
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 2>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
}
template <typename T>
T HullDistance(
const std::array<vec3<T>, 3>& a,
const std::array<vec3<T>, 1>& b) {
return HullDistance(b, a);
}
template <typename T>
void HullHullDistanceBackward(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 2>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
using std::get;
auto res =
PointLine3DistanceBackward(get<0>(a), get<0>(b), get<1>(b), grad_dist);
IncrementPoint(std::move(grad_a), get<0>(res));
IncrementPoint(grad_b[0], get<1>(res));
IncrementPoint(grad_b[1], get<2>(res));
}
template <typename T>
void HullHullDistanceBackward(
const std::array<vec3<T>, 1>& a,
const std::array<vec3<T>, 3>& b,
T grad_dist,
at::TensorAccessor<T, 1>&& grad_a,
at::TensorAccessor<T, 2>&& grad_b) {
using std::get;
auto res = PointTriangle3DistanceBackward(
get<0>(a), get<0>(b), get<1>(b), get<2>(b), grad_dist);
IncrementPoint(std::move(grad_a), get<0>(res));
IncrementPoint(grad_b[0], get<1>(res));
IncrementPoint(grad_b[1], get<2>(res));
IncrementPoint(grad_b[2], get<3>(res));
}
template <typename T>
void HullHullDistanceBackward(
const std::array<vec3<T>, 3>& a,
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
}
template <typename T>
void HullHullDistanceBackward(
const std::array<vec3<T>, 2>& a,
const std::array<vec3<T>, 1>& b,
T grad_dist,
at::TensorAccessor<T, 2>&& grad_a,
at::TensorAccessor<T, 1>&& grad_b) {
return HullHullDistanceBackward(
b, a, grad_dist, std::move(grad_b), std::move(grad_a));
}
template <int H>
void ValidateShape(const at::Tensor& as) {
if (H == 1) {
TORCH_CHECK(as.size(1) == 3);
} else {
TORCH_CHECK(as.size(2) == 3 && as.size(1) == H);
}
}
// ----------- Here begins the implementation of each top-level
// function using non-type template parameters to
// implement all the cases in one go. ----------- //
template <int H1, int H2>
std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
const at::Tensor& as,
const at::Tensor& as_first_idx,
const at::Tensor& bs,
const at::Tensor& bs_first_idx) {
const int64_t A_N = as.size(0);
const int64_t B_N = bs.size(0);
const int64_t BATCHES = as_first_idx.size(0);
ValidateShape<H1>(as);
ValidateShape<H2>(bs);
TORCH_CHECK(bs_first_idx.size(0) == BATCHES);
// clang-format off
at::Tensor dists = at::zeros({A_N,}, as.options());
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
// clang-format on
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
auto dists_a = dists.accessor<float, 1>();
auto idxs_a = idxs.accessor<int64_t, 1>();
int64_t a_batch_end = 0;
int64_t b_batch_start = 0, b_batch_end = 0;
int64_t batch_idx = 0;
for (int64_t a_n = 0; a_n < A_N; ++a_n) {
if (a_n == a_batch_end) {
++batch_idx;
b_batch_start = b_batch_end;
if (batch_idx == BATCHES) {
a_batch_end = std::numeric_limits<int64_t>::max();
b_batch_end = B_N;
} else {
a_batch_end = as_first_idx_a[batch_idx];
b_batch_end = bs_first_idx_a[batch_idx];
}
}
float min_dist = std::numeric_limits<float>::max();
size_t min_idx = 0;
auto a = ExtractHull<H1>(as_a[a_n]);
for (int64_t b_n = b_batch_start; b_n < b_batch_end; ++b_n) {
float dist = HullDistance(a, ExtractHull<H2>(bs_a[b_n]));
if (dist <= min_dist) {
min_dist = dist;
min_idx = b_n;
}
}
dists_a[a_n] = min_dist;
idxs_a[a_n] = min_idx;
}
return std::make_tuple(dists, idxs);
}
template <int H1, int H2>
std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
const at::Tensor& as,
const at::Tensor& bs,
const at::Tensor& idx_bs,
const at::Tensor& grad_dists) {
const int64_t A_N = as.size(0);
TORCH_CHECK(idx_bs.size(0) == A_N);
TORCH_CHECK(grad_dists.size(0) == A_N);
ValidateShape<H1>(as);
ValidateShape<H2>(bs);
at::Tensor grad_as = at::zeros_like(as);
at::Tensor grad_bs = at::zeros_like(bs);
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
auto grad_dists_a = grad_dists.accessor<float, 1>();
for (int64_t a_n = 0; a_n < A_N; ++a_n) {
auto a = ExtractHull<H1>(as_a[a_n]);
auto b = ExtractHull<H2>(bs_a[idx_bs_a[a_n]]);
HullHullDistanceBackward(
a, b, grad_dists_a[a_n], grad_as_a[a_n], grad_bs_a[idx_bs_a[a_n]]);
}
return std::make_tuple(grad_as, grad_bs);
}
template <int H>
torch::Tensor PointHullArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& bs) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
ValidateShape<H>(bs);
at::Tensor dists = at::zeros({P, B_N}, points.options());
auto points_a = points.accessor<float, 2>();
auto bs_a = bs.accessor<float, 3>();
auto dists_a = dists.accessor<float, 2>();
for (int64_t p = 0; p < P; ++p) {
auto point = ExtractHull<1>(points_a[p]);
auto dest = dists_a[p];
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
dest[b_n] = HullDistance(point, b);
}
}
return dists;
}
template <int H>
std::tuple<at::Tensor, at::Tensor> PointHullArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& bs,
const at::Tensor& grad_dists) {
const int64_t P = points.size(0);
const int64_t B_N = bs.size(0);
TORCH_CHECK(points.size(1) == 3, "points must be of shape Px3");
ValidateShape<H>(bs);
TORCH_CHECK((grad_dists.size(0) == P) && (grad_dists.size(1) == B_N));
at::Tensor grad_points = at::zeros({P, 3}, points.options());
at::Tensor grad_bs = at::zeros({B_N, H, 3}, bs.options());
auto points_a = points.accessor<float, 2>();
auto bs_a = bs.accessor<float, 3>();
auto grad_dists_a = grad_dists.accessor<float, 2>();
auto grad_points_a = grad_points.accessor<float, 2>();
auto grad_bs_a = grad_bs.accessor<float, 3>();
for (int64_t p = 0; p < P; ++p) {
auto point = ExtractHull<1>(points_a[p]);
auto grad_point = grad_points_a[p];
auto grad_dist = grad_dists_a[p];
for (int64_t b_n = 0; b_n < B_N; ++b_n) {
auto b = ExtractHull<H>(bs_a[b_n]);
HullHullDistanceBackward(
point, b, grad_dist[b_n], std::move(grad_point), grad_bs_a[b_n]);
}
}
return std::make_tuple(grad_points, grad_bs);
}
// ---------- Here begin the exported functions ------------ //
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
return HullHullDistanceForwardCpu<1, 3>(
points, points_first_idx, tris, tris_first_idx);
}
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
return HullHullDistanceBackwardCpu<1, 3>(
points, tris, idx_points, grad_dists);
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx) {
return HullHullDistanceForwardCpu<3, 1>(
tris, tris_first_idx, points, points_first_idx);
}
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<3, 1>(tris, points, idx_tris, grad_dists);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}
torch::Tensor PointEdgeArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms) {
return PointHullArrayDistanceForwardCpu<2>(points, segms);
}
std::tuple<at::Tensor, at::Tensor> PointFaceArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& tris,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<3>(points, tris, grad_dists);
}
torch::Tensor PointFaceArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris) {
return PointHullArrayDistanceForwardCpu<3>(points, tris);
}
std::tuple<at::Tensor, at::Tensor> PointEdgeArrayDistanceBackwardCpu(
const at::Tensor& points,
const at::Tensor& segms,
const at::Tensor& grad_dists) {
return PointHullArrayDistanceBackwardCpu<2>(points, segms, grad_dists);
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t /*max_points*/) {
return HullHullDistanceForwardCpu<1, 2>(
points, points_first_idx, segms, segms_first_idx);
}
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists) {
return HullHullDistanceBackwardCpu<1, 2>(
points, segms, idx_points, grad_dists);
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t /*max_segms*/) {
return HullHullDistanceForwardCpu<2, 1>(
segms, segms_first_idx, points, points_first_idx);
}
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists) {
auto res =
HullHullDistanceBackwardCpu<2, 1>(segms, points, idx_segms, grad_dists);
return std::make_tuple(std::get<1>(res), std::get<0>(res));
}

View File

@@ -46,6 +46,13 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCuda(
const int64_t max_points);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_points);
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
@@ -64,7 +71,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointEdgeDistanceForwardCpu(
points, points_first_idx, segms, segms_first_idx, max_points);
}
// Backward pass for PointEdgeDistance.
@@ -91,6 +99,12 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCuda(
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
@@ -107,7 +121,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
}
// ****************************************************************************
@@ -150,6 +164,13 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCuda(
const int64_t max_segms);
#endif
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& segms,
const torch::Tensor& segms_first_idx,
const int64_t max_segms);
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
@@ -168,7 +189,8 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return EdgePointDistanceForwardCpu(
points, points_first_idx, segms, segms_first_idx, max_segms);
}
// Backward pass for EdgePointDistance.
@@ -195,6 +217,12 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCuda(
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& idx_segms,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
@@ -211,7 +239,7 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
}
// ****************************************************************************
@@ -242,6 +270,10 @@ torch::Tensor PointEdgeArrayDistanceForwardCuda(
const torch::Tensor& segms);
#endif
torch::Tensor PointEdgeArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms);
torch::Tensor PointEdgeArrayDistanceForward(
const torch::Tensor& points,
const torch::Tensor& segms) {
@@ -254,7 +286,7 @@ torch::Tensor PointEdgeArrayDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointEdgeArrayDistanceForwardCpu(points, segms);
}
// Backward pass for PointEdgeArrayDistance.
@@ -277,6 +309,11 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCuda(
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& segms,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& segms,
@@ -291,5 +328,5 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
}

View File

@@ -19,7 +19,7 @@
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// tris_first_idx: LongTensor of shape (N,) indicating the first face
// index for each example in the batch
// max_points: Scalar equal to max(P_i) for i in [0, N - 1] containing
@@ -48,6 +48,12 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCuda(
const int64_t max_points);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx);
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
@@ -66,7 +72,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointFaceDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx);
}
// Backward pass for PointFaceDistance.
@@ -92,6 +99,11 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCuda(
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_points,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
const torch::Tensor& points,
@@ -109,7 +121,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointFaceDistanceBackwardCpu(points, tris, idx_points, grad_dists);
}
// ****************************************************************************
@@ -124,7 +136,7 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
// points_first_idx: LongTensor of shape (N,) indicating the first point
// index for each example in the batch
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// tris_first_idx: LongTensor of shape (N,) indicating the first face
// index for each example in the batch
// max_tris: Scalar equal to max(T_i) for i in [0, N - 1] containing
@@ -149,9 +161,15 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCuda(
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx,
const int64_t max_tros);
const int64_t max_tris);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
const torch::Tensor& tris,
const torch::Tensor& tris_first_idx);
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
const torch::Tensor& points,
const torch::Tensor& points_first_idx,
@@ -170,7 +188,8 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return FacePointDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx);
}
// Backward pass for FacePointDistance.
@@ -197,6 +216,12 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCuda(
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& idx_tris,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
const torch::Tensor& points,
const torch::Tensor& tris,
@@ -213,7 +238,7 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return FacePointDistanceBackwardCpu(points, tris, idx_tris, grad_dists);
}
// ****************************************************************************
@@ -226,7 +251,7 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
// Args:
// points: FloatTensor of shape (P, 3)
// tris: FloatTensor of shape (T, 3, 3) of the triangular faces. The t-th
// triangulare face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
// triangular face is spanned by (tris[t, 0], tris[t, 1], tris[t, 2])
//
// Returns:
// dists: FloatTensor of shape (P, T), where dists[p, t] is the squared
@@ -245,6 +270,10 @@ torch::Tensor PointFaceArrayDistanceForwardCuda(
const torch::Tensor& tris);
#endif
torch::Tensor PointFaceArrayDistanceForwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris);
torch::Tensor PointFaceArrayDistanceForward(
const torch::Tensor& points,
const torch::Tensor& tris) {
@@ -257,7 +286,7 @@ torch::Tensor PointFaceArrayDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointFaceArrayDistanceForwardCpu(points, tris);
}
// Backward pass for PointFaceArrayDistance.
@@ -278,6 +307,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCuda(
const torch::Tensor& tris,
const torch::Tensor& grad_dists);
#endif
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackwardCpu(
const torch::Tensor& points,
const torch::Tensor& tris,
const torch::Tensor& grad_dists);
std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
const torch::Tensor& points,
@@ -293,5 +326,5 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
AT_ERROR("No CPU implementation.");
return PointFaceArrayDistanceBackwardCpu(points, tris, grad_dists);
}