point mesh distances

Summary:
Implementation of point to mesh distances. The current diff contains two types:
(a) Point to Edge
(b) Point to Face

```

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_EDGE_4_100_300_5000_cuda:0                2745            3138            183
POINT_MESH_EDGE_4_100_300_10000_cuda:0               4408            4499            114
POINT_MESH_EDGE_4_100_3000_5000_cuda:0               4978            5070            101
POINT_MESH_EDGE_4_100_3000_10000_cuda:0              9076            9187             56
POINT_MESH_EDGE_4_1000_300_5000_cuda:0               1411            1487            355
POINT_MESH_EDGE_4_1000_300_10000_cuda:0              4829            5030            104
POINT_MESH_EDGE_4_1000_3000_5000_cuda:0              7539            7620             67
POINT_MESH_EDGE_4_1000_3000_10000_cuda:0            12088           12272             42
POINT_MESH_EDGE_8_100_300_5000_cuda:0                3106            3222            161
POINT_MESH_EDGE_8_100_300_10000_cuda:0               8561            8648             59
POINT_MESH_EDGE_8_100_3000_5000_cuda:0               6932            7021             73
POINT_MESH_EDGE_8_100_3000_10000_cuda:0             24032           24176             21
POINT_MESH_EDGE_8_1000_300_5000_cuda:0               5272            5399             95
POINT_MESH_EDGE_8_1000_300_10000_cuda:0             11348           11430             45
POINT_MESH_EDGE_8_1000_3000_5000_cuda:0             17478           17683             29
POINT_MESH_EDGE_8_1000_3000_10000_cuda:0            25961           26236             20
POINT_MESH_EDGE_16_100_300_5000_cuda:0               8244            8323             61
POINT_MESH_EDGE_16_100_300_10000_cuda:0             18018           18071             28
POINT_MESH_EDGE_16_100_3000_5000_cuda:0             19428           19544             26
POINT_MESH_EDGE_16_100_3000_10000_cuda:0            44967           45135             12
POINT_MESH_EDGE_16_1000_300_5000_cuda:0              7825            7937             64
POINT_MESH_EDGE_16_1000_300_10000_cuda:0            18504           18571             28
POINT_MESH_EDGE_16_1000_3000_5000_cuda:0            65805           66132              8
POINT_MESH_EDGE_16_1000_3000_10000_cuda:0           90885           91089              6
--------------------------------------------------------------------------------

Benchmark                                       Avg Time(μs)      Peak Time(μs) Iterations
--------------------------------------------------------------------------------
POINT_MESH_FACE_4_100_300_5000_cuda:0                1561            1685            321
POINT_MESH_FACE_4_100_300_10000_cuda:0               2818            2954            178
POINT_MESH_FACE_4_100_3000_5000_cuda:0              15893           16018             32
POINT_MESH_FACE_4_100_3000_10000_cuda:0             16350           16439             31
POINT_MESH_FACE_4_1000_300_5000_cuda:0               3179            3278            158
POINT_MESH_FACE_4_1000_300_10000_cuda:0              2353            2436            213
POINT_MESH_FACE_4_1000_3000_5000_cuda:0             16262           16336             31
POINT_MESH_FACE_4_1000_3000_10000_cuda:0             9334            9448             54
POINT_MESH_FACE_8_100_300_5000_cuda:0                4377            4493            115
POINT_MESH_FACE_8_100_300_10000_cuda:0               9728            9822             52
POINT_MESH_FACE_8_100_3000_5000_cuda:0              26428           26544             19
POINT_MESH_FACE_8_100_3000_10000_cuda:0             42238           43031             12
POINT_MESH_FACE_8_1000_300_5000_cuda:0               3891            3982            129
POINT_MESH_FACE_8_1000_300_10000_cuda:0              5363            5429             94
POINT_MESH_FACE_8_1000_3000_5000_cuda:0             20998           21084             24
POINT_MESH_FACE_8_1000_3000_10000_cuda:0            39711           39897             13
POINT_MESH_FACE_16_100_300_5000_cuda:0               5955            6001             84
POINT_MESH_FACE_16_100_300_10000_cuda:0             12082           12144             42
POINT_MESH_FACE_16_100_3000_5000_cuda:0             44996           45176             12
POINT_MESH_FACE_16_100_3000_10000_cuda:0            73042           73197              7
POINT_MESH_FACE_16_1000_300_5000_cuda:0              8292            8374             61
POINT_MESH_FACE_16_1000_300_10000_cuda:0            19442           19506             26
POINT_MESH_FACE_16_1000_3000_5000_cuda:0            36059           36194             14
POINT_MESH_FACE_16_1000_3000_10000_cuda:0           64644           64822              8
--------------------------------------------------------------------------------
```

Reviewed By: jcjohnson

Differential Revision: D20590462

fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
This commit is contained in:
Georgia Gkioxari
2020-04-11 00:18:53 -07:00
committed by Facebook GitHub Bot
parent 474c8b456a
commit 487d4d6607
33 changed files with 3437 additions and 84 deletions

View File

@@ -0,0 +1,343 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
//
// This file provides utilities for dispatching to specialized versions of
// functions. This is especially useful for CUDA kernels, since specializing
// them to particular input sizes can often allow the compiler to unroll loops
// and place arrays into registers, which can give huge performance speedups.
//
// As an example, suppose we have the following function which is specialized
// based on a compile-time int64_t value:
//
// template<typename T, int64_t x>
// struct SquareOffset {
// static void run(T y) {
// T val = x * x + y;
// std::cout << val << std::endl;
// }
// }
//
// This function takes one compile-time argument x, and one run-time argument y.
// We might want to compile specialized versions of this for x=0, x=1, etc and
// then dispatch to the correct one based on the runtime value of x.
// One simple way to achieve this is with a lookup table:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// if (x == 0) {
// SquareOffset<T, 0>::run(y);
// } else if (x == 1) {
// SquareOffset<T, 1>::run(y);
// } else if (x == 2) {
// SquareOffset<T, 2>::run(y);
// }
// }
//
// This function takes both x and y as run-time arguments, and dispatches to
// different specialized versions of SquareOffset based on the run-time value
// of x. This works, but it's tedious and error-prone. If we want to change the
// set of x values for which we provide compile-time specializations, then we
// will need to do a lot of tedious editing of the dispatch function. Also, if we
// want to provide compile-time specializations for another function other than
// SquareOffset, we will need to duplicate the entire lookup table.
//
// To solve these problems, we can use the DispatchKernel1D function provided by
// this file instead:
//
// template<typename T>
// void DispatchSquareOffset(const int64_t x, T y) {
// constexpr int64_t xmin = 0;
// constexpr int64_t xmax = 2;
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
// }
//
// DispatchKernel1D uses template metaprogramming to compile specialized
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
// then dispatches to the correct one based on the run-time value of x. If we
// want to change the range of x values for which SquareOffset is specialized
// at compile-time, then all we have to do is change the values of the
// compile-time constants xmin and xmax.
//
// This file also allows us to similarly dispatch functions that depend on two
// compile-time int64_t values, using the DispatchKernel2D function like this:
//
// template<typename T, int64_t x, int64_t y>
// struct Sum {
// static void run(T z, T w) {
// T val = x + y + z + w;
// std::cout << val << std::endl;
// }
// }
//
// template<typename T>
// void DispatchSum(const int64_t x, const int64_t y, T z, T w) {
// constexpr int64_t xmin = 1;
// constexpr int64_t xmax = 3;
// constexpr int64_t ymin = 2;
// constexpr int64_t ymax = 5;
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
// }
//
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
// compile specialized versions of Sum for all values of (x, y) with
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
// specialized version based on the runtime values of x and y.
// Define some helper structs in an anonymous namespace.
namespace {
// 1D dispatch: primary template (general case).
// Kernel is the class template being dispatched to. It must accept a typename
// and an int64_t as template arguments and expose a static void function
// run(...) taking any number of arguments of any type.
// The extra template argument curN acts as a compile-time counter: template
// recursion steps it upward until it matches the run-time argument N.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
      // The compile-time counter matches the run-time value, so invoke the
      // kernel specialized for curN.
      Kernel<T, curN>::run(args...);
      return;
    }
    if (curN < N) {
      // Not there yet: hand over to the helper for the next value of curN.
      using Next =
          DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>;
      Next::run(N, args...);
    }
    // curN > N: nothing is dispatched.
    // We shouldn't get here -- throw an error?
  }
};
// 1D dispatch: partial specialization for curN == maxN.
// This is the base case that terminates the template recursion started by the
// general case above; without it the recursion would never stop.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    typename... Args>
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
  static void run(const int64_t N, Args... args) {
    if (N != maxN) {
      // No specialization left to try, so nothing is dispatched.
      // We shouldn't get here -- throw an error?
      return;
    }
    Kernel<T, maxN>::run(args...);
  }
};
// 2D dispatch: primary template (general case).
// Works like the 1D helper, but with two compile-time counters, curN and
// curM. Template recursion advances them until they equal the run-time values
// N and M, at which point the kernel's run method is invoked.
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    // Helpers reached by advancing both counters, only curN, or only curM.
    using StepBoth = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN + 1,
        minM,
        maxM,
        curM + 1,
        Args...>;
    using StepN = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN + 1,
        minM,
        maxM,
        curM,
        Args...>;
    using StepM = DispatchKernelHelper2D<
        Kernel,
        T,
        minN,
        maxN,
        curN,
        minM,
        maxM,
        curM + 1,
        Args...>;

    if (curN == N && curM == M) {
      // Both counters match the run-time values: dispatch.
      Kernel<T, curN, curM>::run(args...);
      return;
    }

    const bool n_behind = curN < N;
    const bool m_behind = curM < M;
    if (n_behind && m_behind) {
      // Advance both counters at once. Stepping one at a time would also
      // work, but this cuts down the number of recursive calls.
      StepBoth::run(N, M, args...);
    } else if (n_behind) {
      // Only curN still has to catch up.
      StepN::run(N, M, args...);
    } else if (m_behind) {
      // Only curM still has to catch up.
      StepM::run(N, M, args...);
    }
  }
};
// 2D dispatch, specialization for curN == maxN
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
int64_t curM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (maxN == N && curM == M) {
Kernel<T, maxN, curM>::run(args...);
} else if (curM < maxM) {
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
maxN,
minM,
maxM,
curM + 1,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch, specialization for curM == maxM
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t curN,
int64_t minM,
int64_t maxM,
typename... Args>
struct DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN,
minM,
maxM,
maxM,
Args...> {
static void run(const int64_t N, const int64_t M, Args... args) {
if (curN == N && maxM == M) {
Kernel<T, curN, maxM>::run(args...);
} else if (curN < maxN) {
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
curN + 1,
minM,
maxM,
maxM,
Args...>::run(N, M, args...);
}
// We should not get here -- throw an error?
}
};
// 2D dispatch: partial specialization for curN == maxN and curM == maxM.
// Both counters are at their upper bounds; this is the final base case of the
// template recursion.
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t minM,
    int64_t maxM,
    typename... Args>
struct DispatchKernelHelper2D<
    Kernel,
    T,
    minN,
    maxN,
    maxN,
    minM,
    maxM,
    maxM,
    Args...> {
  static void run(const int64_t N, const int64_t M, Args... args) {
    const bool at_target = (maxN == N) && (maxM == M);
    if (!at_target) {
      // Nothing further to try, so nothing is dispatched.
      // We should not get here -- throw an error?
      return;
    }
    Kernel<T, maxN, maxM>::run(args...);
  }
};
} // namespace
// This is the function we expect users to call to dispatch to 1D functions
template <
template <typename, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
typename... Args>
void DispatchKernel1D(const int64_t N, Args... args) {
if (minN <= N && N <= maxN) {
// Kick off the template recursion by calling the Helper with curN = minN
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
N, args...);
}
// Maybe throw an error if we tried to dispatch outside the allowed range?
}
// This is the function we expect users to call to dispatch to 2D functions
template <
template <typename, int64_t, int64_t> class Kernel,
typename T,
int64_t minN,
int64_t maxN,
int64_t minM,
int64_t maxM,
typename... Args>
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
// Kick off the template recursion by calling the Helper with curN = minN
// and curM = minM
DispatchKernelHelper2D<
Kernel,
T,
minN,
maxN,
minN,
minM,
maxM,
minM,
Args...>::run(N, M, args...);
}
// Maybe throw an error if we tried to dispatch outside the specified range?
}

View File

@@ -0,0 +1,139 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <thrust/tuple.h>
// Set epsilon: a small constant that this header adds to vector norms before
// dividing by them (see normalize and normalize_backward).
// Under MSVC it is a macro expanding to a float literal; with other compilers
// it is a namespace-scope constant whose deduced type is double, because the
// literal has no `f` suffix.
#ifdef _MSC_VER
#define vEpsilon 1e-8f
#else
const auto vEpsilon = 1e-8;
#endif
// Common functions and operators for float2.

// Component-wise difference a - b.
__device__ inline float2 operator-(const float2& a, const float2& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  return make_float2(dx, dy);
}

// Component-wise sum a + b.
__device__ inline float2 operator+(const float2& a, const float2& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  return make_float2(sx, sy);
}

// Component-wise quotient a / b.
__device__ inline float2 operator/(const float2& a, const float2& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  return make_float2(qx, qy);
}

// Divides every component of a by the scalar b.
__device__ inline float2 operator/(const float2& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  return make_float2(qx, qy);
}

// Component-wise product a * b.
__device__ inline float2 operator*(const float2& a, const float2& b) {
  const float mx = a.x * b.x;
  const float my = a.y * b.y;
  return make_float2(mx, my);
}

// Scales every component of b by the scalar a.
__device__ inline float2 operator*(const float a, const float2& b) {
  const float mx = a * b.x;
  const float my = a * b.y;
  return make_float2(mx, my);
}
// Dot product of two float2 vectors.
__device__ inline float dot(const float2& a, const float2& b) {
  const float xx = a.x * b.x;
  const float yy = a.y * b.y;
  return xx + yy;
}
// Backward pass for the float2 dot product.
//
// Args:
//     a, b: The two vectors that were passed to dot.
//     grad_dot: Upstream gradient with respect to the dot product.
//
// Returns:
//     tuple (float2 grad_a, float2 grad_b), where grad_a = grad_dot * b and
//     grad_b = grad_dot * a.
//
__device__ inline thrust::tuple<float2, float2>
DotBackward(const float2& a, const float2& b, const float& grad_dot) {
  const float2 grad_a = grad_dot * b;
  const float2 grad_b = grad_dot * a;
  return thrust::make_tuple(grad_a, grad_b);
}
// Sum of the two components of a float2.
__device__ inline float sum(const float2& a) {
  const float total = a.x + a.y;
  return total;
}
// Common functions and operators for float3.

// Component-wise difference a - b.
__device__ inline float3 operator-(const float3& a, const float3& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  const float dz = a.z - b.z;
  return make_float3(dx, dy, dz);
}

// Component-wise sum a + b.
__device__ inline float3 operator+(const float3& a, const float3& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  const float sz = a.z + b.z;
  return make_float3(sx, sy, sz);
}

// Component-wise quotient a / b.
__device__ inline float3 operator/(const float3& a, const float3& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  const float qz = a.z / b.z;
  return make_float3(qx, qy, qz);
}

// Divides every component of a by the scalar b.
__device__ inline float3 operator/(const float3& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  const float qz = a.z / b;
  return make_float3(qx, qy, qz);
}

// Component-wise product a * b.
__device__ inline float3 operator*(const float3& a, const float3& b) {
  const float mx = a.x * b.x;
  const float my = a.y * b.y;
  const float mz = a.z * b.z;
  return make_float3(mx, my, mz);
}

// Scales every component of b by the scalar a.
__device__ inline float3 operator*(const float a, const float3& b) {
  const float mx = a * b.x;
  const float my = a * b.y;
  const float mz = a * b.z;
  return make_float3(mx, my, mz);
}
// Dot product of two float3 vectors.
__device__ inline float dot(const float3& a, const float3& b) {
  const float xx = a.x * b.x;
  const float yy = a.y * b.y;
  const float zz = a.z * b.z;
  return xx + yy + zz;
}
// Sum of the three components of a float3.
__device__ inline float sum(const float3& a) {
  const float total = a.x + a.y + a.z;
  return total;
}
// Cross product a x b of two float3 vectors.
__device__ inline float3 cross(const float3& a, const float3& b) {
  const float cx = a.y * b.z - a.z * b.y;
  const float cy = a.z * b.x - a.x * b.z;
  const float cz = a.x * b.y - a.y * b.x;
  return make_float3(cx, cy, cz);
}
// Backward pass for the float3 cross product c = a x b.
//
// Args:
//     a, b: The two vectors that were passed to cross.
//     grad_cross: Upstream gradient with respect to the cross product.
//
// Returns:
//     tuple (float3 grad_a, float3 grad_b) of gradients for the two inputs.
//
__device__ inline thrust::tuple<float3, float3>
cross_backward(const float3& a, const float3& b, const float3& grad_cross) {
  // Gradient with respect to a.
  const float3 grad_a = make_float3(
      -grad_cross.y * b.z + grad_cross.z * b.y,
      grad_cross.x * b.z - grad_cross.z * b.x,
      -grad_cross.x * b.y + grad_cross.y * b.x);

  // Gradient with respect to b.
  const float3 grad_b = make_float3(
      grad_cross.y * a.z - grad_cross.z * a.y,
      -grad_cross.x * a.z + grad_cross.z * a.x,
      grad_cross.x * a.y - grad_cross.y * a.x);

  return thrust::make_tuple(grad_a, grad_b);
}
// Euclidean length of a float3 vector.
__device__ inline float norm(const float3& a) {
  const float squared_length = dot(a, a);
  return sqrt(squared_length);
}
// Scales a by 1 / (norm(a) + vEpsilon). The epsilon keeps the divisor away
// from zero when a is the zero vector.
__device__ inline float3 normalize(const float3& a) {
  const float divisor = norm(a) + vEpsilon;
  return a / divisor;
}
// Backward pass for normalize.
//
// Args:
//     a: The vector that was passed to normalize.
//     grad_normz: Upstream gradient with respect to the normalized vector.
//
// Returns:
//     float3 gradient with respect to a.
//
__device__ inline float3 normalize_backward(
    const float3& a,
    const float3& grad_normz) {
  // Recompute the forward pass.
  const float len = norm(a) + vEpsilon;
  const float3 unit = a / len;

  // Entries of the symmetric matrix (I - unit * unit^T); each one is divided
  // by len where it is used below.
  const float j_xx = 1.0f - unit.x * unit.x;
  const float j_yy = 1.0f - unit.y * unit.y;
  const float j_zz = 1.0f - unit.z * unit.z;
  const float j_xy = -unit.x * unit.y;
  const float j_xz = -unit.x * unit.z;
  const float j_yz = -unit.y * unit.z;

  const float grad_ax = grad_normz.x * j_xx / len + grad_normz.y * j_xy / len +
      grad_normz.z * j_xz / len;
  const float grad_ay = grad_normz.x * j_xy / len + grad_normz.y * j_yy / len +
      grad_normz.z * j_yz / len;
  const float grad_az = grad_normz.x * j_xz / len + grad_normz.y * j_yz / len +
      grad_normz.z * j_zz / len;
  return make_float3(grad_ax, grad_ay, grad_az);
}

View File

@@ -0,0 +1,651 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <torch/extension.h>
#include <cstdio>
#include "float_math.cuh"
// Set epsilon for preventing floating point errors and division by 0.
// Under MSVC it is a macro expanding to a float literal; with other compilers
// it is a namespace-scope constant whose deduced type is double, because the
// literal has no `f` suffix.
#ifdef _MSC_VER
#define kEpsilon 1e-8f
#else
const auto kEpsilon = 1e-8;
#endif
// ************************************************************* //
// vec2 utils //
// ************************************************************* //
// Edge function: tells on which side of the 2D line through v0 and v1 the
// point p lies, via the sign of a 2D cross product.
//
// Args:
//     p: vec2 Coordinates of a point.
//     v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
//     area: The signed area of the parallelogram given by the vectors
//           A = p - v0
//           B = v1 - v0
//
__device__ inline float
EdgeFunctionForward(const float2& p, const float2& v0, const float2& v1) {
  // A = p - v0
  const float ax = p.x - v0.x;
  const float ay = p.y - v0.y;
  // B = v1 - v0
  const float bx = v1.x - v0.x;
  const float by = v1.y - v0.y;
  return ax * by - ay * bx;
}
// Backward pass for the edge function: partial derivatives with respect to
// each of the input points, scaled by the upstream gradient.
//
// Args:
//     p: vec2 Coordinates of a point.
//     v0, v1: vec2 Coordinates of the end points of the edge.
//     grad_edge: Upstream gradient for output from edge function.
//
// Returns:
//     tuple of gradients for each of the input points:
//     (float2 d_edge_dp, float2 d_edge_dv0, float2 d_edge_dv1)
//
__device__ inline thrust::tuple<float2, float2, float2> EdgeFunctionBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_edge) {
  // Local derivatives of the edge function with respect to p, v0 and v1.
  const float2 local_dp = make_float2(v1.y - v0.y, v0.x - v1.x);
  const float2 local_dv0 = make_float2(p.y - v1.y, v1.x - p.x);
  const float2 local_dv1 = make_float2(v0.y - p.y, p.x - v0.x);

  // Chain rule: scale each by the upstream gradient.
  const float2 grad_p = grad_edge * local_dp;
  const float2 grad_v0 = grad_edge * local_dv0;
  const float2 grad_v1 = grad_edge * local_dv1;
  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: barycentric coordinates of a 2D point relative to a triangle,
// computed as ratios of edge functions to the triangle's edge-function area.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates.
//
__device__ inline float3 BarycentricCoordsForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // kEpsilon keeps the divisor away from zero for degenerate triangles.
  const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;

  // Edge function of p against the edge opposite each vertex.
  const float e0 = EdgeFunctionForward(p, v1, v2);
  const float e1 = EdgeFunctionForward(p, v2, v0);
  const float e2 = EdgeFunctionForward(p, v0, v1);

  return make_float3(e0 / area, e1 / area, e2 / area);
}
// The backward pass for computing the barycentric coordinates of a point
// relative to a triangle.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: (x, y) coordinates of the triangle vertices.
//     grad_bary_upstream: vec3<T> Upstream gradient for each of the
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//    tuple of gradients for the point and each of the triangle vertices:
//        (float2 grad_p, float2 grad_v0, float2 grad_v1, float2 grad_v2)
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
BarycentricCoordsBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float3& grad_bary_upstream) {
  // Recompute the forward pass: w_i = e_i / area.
  const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Plain multiplication instead of pow(area, 2.0f): cheaper for a square.
  const float area2 = area * area;
  const float e0 = EdgeFunctionForward(p, v1, v2);
  const float e1 = EdgeFunctionForward(p, v2, v0);
  const float e2 = EdgeFunctionForward(p, v0, v1);

  const float grad_w0 = grad_bary_upstream.x;
  const float grad_w1 = grad_bary_upstream.y;
  const float grad_w2 = grad_bary_upstream.z;

  // Calculate component of the gradient from each of w0, w1 and w2.
  // e.g. for w0:
  // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
  //               + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
  //
  // EdgeFunctionBackward(a, b, c, g) returns gradients in argument order
  // (a, b, c). For the area term the call is (v2, v0, v1), so index 0 is v2,
  // index 1 is v0 and index 2 is v1.

  // w0 = e0 / area, with e0 = EdgeFunctionForward(p, v1, v2).
  const float dw0_darea = -e0 / (area2);
  const float dw0_e0 = 1 / area;
  const float dloss_d_w0area = grad_w0 * dw0_darea;
  const float dloss_e0 = grad_w0 * dw0_e0;
  auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
  auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
  const float2 dw0_p = thrust::get<0>(de0_dv);
  const float2 dw0_dv0 = thrust::get<1>(dw0area_dv);
  const float2 dw0_dv1 = thrust::get<1>(de0_dv) + thrust::get<2>(dw0area_dv);
  const float2 dw0_dv2 = thrust::get<2>(de0_dv) + thrust::get<0>(dw0area_dv);

  // w1 = e1 / area, with e1 = EdgeFunctionForward(p, v2, v0).
  const float dw1_darea = -e1 / (area2);
  const float dw1_e1 = 1 / area;
  const float dloss_d_w1area = grad_w1 * dw1_darea;
  const float dloss_e1 = grad_w1 * dw1_e1;
  auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
  auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
  const float2 dw1_p = thrust::get<0>(de1_dv);
  const float2 dw1_dv0 = thrust::get<2>(de1_dv) + thrust::get<1>(dw1area_dv);
  const float2 dw1_dv1 = thrust::get<2>(dw1area_dv);
  const float2 dw1_dv2 = thrust::get<1>(de1_dv) + thrust::get<0>(dw1area_dv);

  // w2 = e2 / area, with e2 = EdgeFunctionForward(p, v0, v1).
  const float dw2_darea = -e2 / (area2);
  const float dw2_e2 = 1 / area;
  const float dloss_d_w2area = grad_w2 * dw2_darea;
  const float dloss_e2 = grad_w2 * dw2_e2;
  auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
  auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
  const float2 dw2_p = thrust::get<0>(de2_dv);
  const float2 dw2_dv0 = thrust::get<1>(de2_dv) + thrust::get<1>(dw2area_dv);
  const float2 dw2_dv1 = thrust::get<2>(de2_dv) + thrust::get<2>(dw2area_dv);
  const float2 dw2_dv2 = thrust::get<0>(dw2area_dv);

  // Accumulate the contributions of the three coordinates per input.
  const float2 dbary_p = dw0_p + dw1_p + dw2_p;
  const float2 dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
  const float2 dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
  const float2 dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
  return thrust::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
}
// Forward pass: perspective-correct a set of barycentric coordinates. Each
// coordinate is weighted by the product of the other two vertices' z values
// and the result is renormalized so the weights sum to one.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
//     World-space barycentric coordinates
//
__device__ inline float3 BarycentricPerspectiveCorrectionForward(
    const float3& bary,
    const float z0,
    const float z1,
    const float z2) {
  // Unnormalized weights.
  const float num0 = bary.x * z1 * z2;
  const float num1 = z0 * bary.y * z2;
  const float num2 = z0 * z1 * bary.z;

  // Normalizer shared by all three coordinates.
  const float total = num0 + num1 + num2;

  return make_float3(num0 / total, num1 / total, num2 / total);
}
// Backward pass for the perspective correction of barycentric coordinates.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//     grad_out: Upstream gradient of the loss with respect to the corrected
//               barycentric coordinates.
//
// Returns a tuple of:
//      grad_bary: Downstream gradient of the loss with respect to the
//                 uncorrected barycentric coordinates.
//      grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
//                                 to the z-coordinates of the triangle verts
__device__ inline thrust::tuple<float3, float, float, float>
BarycentricPerspectiveCorrectionBackward(
    const float3& bary,
    const float z0,
    const float z1,
    const float z2,
    const float3& grad_out) {
  // Recompute the forward pass: w_i = num_i / total.
  const float num0 = bary.x * z1 * z2;
  const float num1 = z0 * bary.y * z2;
  const float num2 = z0 * z1 * bary.z;
  const float total = num0 + num1 + num2;

  // Gradient flowing through the shared normalizer.
  const float grad_total_num =
      -num0 * grad_out.x - num1 * grad_out.y - num2 * grad_out.z;
  const float grad_total = grad_total_num / (total * total);

  // Gradient with respect to each unnormalized weight: the normalizer path
  // plus the direct path.
  const float grad_num0 = grad_total + grad_out.x / total;
  const float grad_num1 = grad_total + grad_out.y / total;
  const float grad_num2 = grad_total + grad_out.z / total;

  // Gradient with respect to the uncorrected barycentric coordinates.
  const float3 grad_bary = make_float3(
      grad_num0 * z1 * z2, grad_num1 * z0 * z2, grad_num2 * z0 * z1);

  // Gradient with respect to each z: every z appears in the two weights that
  // belong to the other vertices.
  const float grad_z0 = grad_num1 * bary.y * z2 + grad_num2 * bary.z * z1;
  const float grad_z1 = grad_num0 * bary.x * z2 + grad_num2 * bary.z * z0;
  const float grad_z2 = grad_num0 * bary.x * z1 + grad_num1 * bary.y * z0;

  return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (b - a) and a
// point p.
//
// Args:
//     p: Coordinates of a point.
//     a, b: Coordinates of the end points of the line segment.
//
// Returns:
//     squared distance from p to the closest point on the segment. If the
//     segment is degenerate (squared length <= kEpsilon), the squared distance
//     from p to the end point b is returned instead.
//
__device__ inline float
PointLineDistanceForward(const float2& p, const float2& a, const float2& b) {
  const float2 ba = b - a;
  const float l2 = dot(ba, ba);

  // Handle the degenerate segment before dividing by its squared length, so
  // that no division by (near) zero is performed.
  if (l2 <= kEpsilon) {
    const float2 pb = p - b;
    return dot(pb, pb);
  }

  // Parameter of the projection of p onto the line through a and b, clamped
  // to the interval [+0.0, 1.0] so the projection stays on the segment.
  const float t = __saturatef(dot(ba, p - a) / l2);
  const float2 p_proj = a + t * ba;
  const float2 d = (p_proj - p);
  return dot(d, d); // squared distance
}
// Backward pass for point to line distance in 2D.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//     grad_dist: Upstream gradient for the (squared) distance.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (float2 grad_p, float2 grad_v0, float2 grad_v1)
//
// NOTE(review): unlike PointLineDistanceForward, there is no special case
// here for a degenerate segment (squared length <= kEpsilon) -- confirm that
// this is intended.
//
__device__ inline thrust::tuple<float2, float2, float2>
PointLineDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_dist) {
  // Redo some of the forward pass calculations.
  const float2 v1v0 = v1 - v0;
  const float2 pv0 = p - v0;
  const float t_bot = dot(v1v0, v1v0);
  const float t_top = dot(v1v0, pv0);
  float tt = t_top / t_bot;
  tt = __saturatef(tt);
  // Closest point on the segment, written as a blend of the two end points.
  const float2 p_proj = (1.0f - tt) * v0 + tt * v1;
  // Offset from p to its projection; every gradient is a multiple of it.
  const float2 diff = p_proj - p;

  // The squared distance is dot(diff, diff). tt is treated as a constant:
  // the gradients below contain no term for the dependence of tt on the
  // inputs.
  const float2 grad_p = -1.0f * grad_dist * 2.0f * diff;
  const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * diff;
  const float2 grad_v1 = grad_dist * tt * 2.0f * diff;
  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: squared distance from a 2D point to the closest of the three
// edges of a triangle.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
//     smallest of the squared distances from p to the edges (v0, v1),
//     (v0, v2) and (v1, v2).
//
__device__ inline float PointTriangleDistanceForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // Squared distance to each edge, reduced to the minimum step by step.
  float closest = fminf(
      PointLineDistanceForward(p, v0, v1), PointLineDistanceForward(p, v0, v2));
  closest = fminf(closest, PointLineDistanceForward(p, v1, v2));
  return closest;
}
// Backward pass for the point to triangle distance. The gradient is that of
// the distance to whichever edge is closest to p; the vertex that is not part
// of that edge receives a zero gradient.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for the point and each of the triangle vertices:
//         (float2 grad_p, float2 grad_v0, float2 grad_v1, float2 grad_v2)
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
PointTriangleDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float& grad_dist) {
  // Squared distance from p to each of the three edges.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);

  // Which edge is the closest? Ties go to the first test that passes below.
  const bool e01_closest = d01 <= d02 && d01 <= d12;
  const bool e02_closest = d02 <= d01 && d02 <= d12;
  const bool e12_closest = d12 <= d01 && d12 <= d02;

  // All gradients start at zero.
  float2 grad_p = make_float2(0.0f, 0.0f);
  float2 grad_v0 = make_float2(0.0f, 0.0f);
  float2 grad_v1 = make_float2(0.0f, 0.0f);
  float2 grad_v2 = make_float2(0.0f, 0.0f);

  if (e01_closest) {
    // Closest edge is v1 - v0.
    auto edge_grads = PointLineDistanceBackward(p, v0, v1, grad_dist);
    grad_p = thrust::get<0>(edge_grads);
    grad_v0 = thrust::get<1>(edge_grads);
    grad_v1 = thrust::get<2>(edge_grads);
  } else if (e02_closest) {
    // Closest edge is v2 - v0.
    auto edge_grads = PointLineDistanceBackward(p, v0, v2, grad_dist);
    grad_p = thrust::get<0>(edge_grads);
    grad_v0 = thrust::get<1>(edge_grads);
    grad_v2 = thrust::get<2>(edge_grads);
  } else if (e12_closest) {
    // Closest edge is v2 - v1.
    auto edge_grads = PointLineDistanceBackward(p, v1, v2, grad_dist);
    grad_p = thrust::get<0>(edge_grads);
    grad_v1 = thrust::get<1>(edge_grads);
    grad_v2 = thrust::get<2>(edge_grads);
  }

  return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
// ************************************************************* //
// vec3 utils //
// ************************************************************* //
// Barycentric coordinates of a 3D point p with respect to the triangle
// (v0, v1, v2), i.e. weights with p = w0 * v0 + w1 * v1 + w2 * v2
// and w0 + w1 + w2 = 1.0
//
// NOTE that this function assumes that p lives on the space spanned
// by (v0, v1, v2).
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates
//
__device__ inline float3 BarycentricCoords3Forward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  // Triangle edges and the query point, all relative to v0.
  const float3 edge1 = v1 - v0;
  const float3 edge2 = v2 - v0;
  const float3 rel = p - v0;

  // Pairwise dot products used by the 2x2 solve below.
  const float e1e1 = dot(edge1, edge1);
  const float e1e2 = dot(edge1, edge2);
  const float e2e2 = dot(edge2, edge2);
  const float re1 = dot(rel, edge1);
  const float re2 = dot(rel, edge2);

  // kEpsilon keeps the divisor away from zero for degenerate triangles.
  const float denom = e1e1 * e2e2 - e1e2 * e1e2 + kEpsilon;
  const float w1 = (e2e2 * re1 - e1e2 * re2) / denom;
  const float w2 = (e1e1 * re2 - e1e2 * re1) / denom;
  // The three weights sum to one.
  const float w0 = 1.0f - w1 - w2;
  return make_float3(w0, w1, w2);
}
// Tests whether the point p lies inside the triangle (v0, v1, v2).
// The point counts as inside when each of its barycentric coordinates with
// respect to the triangle is within [0, 1] (bounds included).
//
// NOTE: p is assumed to lie in the plane spanned by (v0, v1, v2); this is
// not verified here.
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
//     inside: bool indicating whether p is inside the triangle
//
__device__ inline bool IsInsideTriangle(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  const float3 w = BarycentricCoords3Forward(p, v0, v1, v2);
  // Each guard is written as !(in range) so that a NaN coordinate is
  // rejected, exactly like a failed range comparison.
  if (!(0.0f <= w.x && w.x <= 1.0f)) {
    return false;
  }
  if (!(0.0f <= w.y && w.y <= 1.0f)) {
    return false;
  }
  return 0.0f <= w.z && w.z <= 1.0f;
}
// Minimum squared Euclidean distance between the point p and the segment
// with end points (v0, v1).
// Points on the supporting line are parametrized as x(t) = v0 + t * (v1 - v0);
// the t minimizing (x(t) - p)^2 is computed and clamped to [0, 1] so that the
// closest point stays on the segment.
// p does not need to lie on the line through (v0, v1).
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1: vec3 coordinates of start and end of segment
//
// Returns:
//     dist: the minimum squared distance of p from segment (v0, v1)
//
__device__ inline float
PointLine3DistanceForward(const float3& p, const float3& v0, const float3& v1) {
  const float3 seg = v1 - v0;
  const float3 offset = p - v0;
  const float seg_len2 = dot(seg, seg);
  const float proj = dot(offset, seg);

  // A (near) zero-length segment means v0 == v1; use t = 0 in that case.
  float t = 0.0f;
  if (!(seg_len2 < kEpsilon)) {
    t = proj / seg_len2;
  }
  t = __saturatef(t); // clamps to [0, 1]

  const float3 closest = v0 + t * seg;
  const float3 delta = p - closest;
  return dot(delta, delta);
}
// Backward function of the minimum squared Euclidean distance between the point
// p and the line segment (v0, v1).
//
// The branch taken depends on where the unclamped projection parameter
// tt = dot(v1 - v0, p - v0) / |v1 - v0|^2 falls:
//   - segment of (near) zero length: gradient is shared by v0 and v1
//   - tt < 0: only p and v0 receive gradient
//   - tt > 1: only p and v1 receive gradient
//   - otherwise: p, v0 and v1 all receive gradient
//
// Args:
// p: vec3 coordinates of a point
// v0, v1: vec3 coordinates of start and end of segment
// grad_dist: Float of the gradient wrt dist
//
// Returns:
// tuple of gradients for the point and line segment (v0, v1):
// (float3 grad_p, float3 grad_v0, float3 grad_v1)
__device__ inline thrust::tuple<float3, float3, float3>
PointLine3DistanceBackward(
const float3& p,
const float3& v0,
const float3& v1,
const float& grad_dist) {
const float3 v1v0 = v1 - v0;
const float3 pv0 = p - v0;
// Squared segment length and the (unnormalized) projection of pv0 onto it.
const float t_bot = dot(v1v0, v1v0);
const float t_top = dot(v1v0, pv0);
// Gradients default to zero; each branch below fills in only what it needs.
float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
// Computed before the degenerate check; when t_bot is tiny the first branch
// below is taken and does not read tt.
const float tt = t_top / t_bot;
if (t_bot < kEpsilon) {
// if t_bot small, then v0 == v1,
// and dist = 0.5 * dot(pv0, pv0) + 0.5 * dot(pv1, pv1)
// NOTE(review): PointLine3DistanceForward uses tt = 0 here, i.e.
// dist = dot(pv0, pv0); this branch instead splits the vertex gradient
// evenly between v0 and v1.
grad_p = grad_dist * 2.0f * pv0;
grad_v0 = -0.5f * grad_p;
grad_v1 = grad_v0;
} else if (tt < 0.0f) {
// Projection falls before v0: dist is treated as |p - v0|^2.
grad_p = grad_dist * 2.0f * pv0;
grad_v0 = -1.0f * grad_p;
// no gradients wrt v1
} else if (tt > 1.0f) {
// Projection falls past v1: dist is treated as |p - v1|^2.
grad_p = grad_dist * 2.0f * (p - v1);
grad_v1 = -1.0f * grad_p;
// no gradients wrt v0
} else {
// Projection lies on the segment: differentiate |p - p_proj|^2 where
// p_proj = v0 + tt * v1v0 and tt itself depends on p, v0 and v1.
const float3 p_proj = v0 + tt * v1v0;
const float3 diff = p - p_proj;
const float3 grad_base = grad_dist * 2.0f * diff;
grad_p = grad_base - dot(grad_base, v1v0) * v1v0 / t_bot;
// dtt_v0 / dtt_v1: derivative of tt with respect to v0 / v1.
const float3 dtt_v0 = (-1.0f * v1v0 - pv0 + 2.0f * tt * v1v0) / t_bot;
grad_v0 = (-1.0f + tt) * grad_base - dot(grad_base, v1v0) * dtt_v0;
const float3 dtt_v1 = (pv0 - 2.0f * tt * v1v0) / t_bot;
grad_v1 = -dot(grad_base, v1v0) * dtt_v1 - tt * grad_base;
}
return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
// Squared distance from a point p to the triangle (v0, v1, v2).
//
// p is first projected onto the plane of the triangle. If that projection
// falls inside the triangle (and the triangle is not degenerate), the result
// is the squared distance from p to its projection. Otherwise the result is
// the smallest of the squared distances from p to the three edges
// (v0, v1), (v0, v2) and (v1, v2).
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
//     dist: Float of the squared distance
//
__device__ inline float PointTriangle3DistanceForward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  const float3 raw_normal = cross(v2 - v0, v1 - v0);
  // Length of the unnormalized normal; used to detect degenerate triangles.
  const float raw_norm = norm(raw_normal);
  const float3 unit_normal = normalize(raw_normal);

  // Projection of p on the triangle's plane: p_proj = p + t * unit_normal,
  // chosen so that (p_proj - v0) is orthogonal to the normal.
  const float t = dot(v0 - p, unit_normal);
  const float3 p_proj = p + t * unit_normal;

  if (IsInsideTriangle(p_proj, v0, v1, v2) && (raw_norm > kEpsilon)) {
    // The projection is the closest point: dist = norm(p_proj - p)^2.
    return t * t;
  }

  // Otherwise the closest point lies on one of the edges.
  const float e01 = PointLine3DistanceForward(p, v0, v1);
  const float e02 = PointLine3DistanceForward(p, v0, v2);
  const float e12 = PointLine3DistanceForward(p, v1, v2);
  const float best_of_two = (e01 > e02) ? e02 : e01;
  return (best_of_two > e12) ? e12 : best_of_two;
}
// The backward pass for computing the squared distance of a point
// to the triangle (v0, v1, v2).
//
// The same case split as in PointTriangle3DistanceForward is recomputed here:
// either the projection of p on the triangle's plane is inside the triangle
// (gradient of t * t), or the closest edge is selected and the gradient of
// the corresponding point-segment distance is returned.
//
// Args:
// p: xyz coordinates of a point
// v0, v1, v2: xyz coordinates of the triangle vertices
// grad_dist: Float of the gradient wrt dist
//
// Returns:
// tuple of gradients for the point and triangle:
// (float3 grad_p, float3 grad_v0, float3 grad_v1, float3 grad_v2)
//
__device__ inline thrust::tuple<float3, float3, float3, float3>
PointTriangle3DistanceBackward(
const float3& p,
const float3& v0,
const float3& v1,
const float3& v2,
const float& grad_dist) {
const float3 v2v0 = v2 - v0;
const float3 v1v0 = v1 - v0;
const float3 v0p = v0 - p;
// Unnormalized plane normal; its length is used to detect degenerate
// triangles.
float3 raw_normal = cross(v2v0, v1v0);
const float norm_normal = norm(raw_normal);
float3 normal = normalize(raw_normal);
// p0 is the projection of p on the plane spanned by (v0, v1, v2)
// i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
const float t = dot(v0 - p, normal);
const float3 p0 = p + t * normal;
// Offset from p to its projection p0.
const float3 diff = t * normal;
bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
// Gradients default to zero; vertices that do not take part in the selected
// case keep a zero gradient.
float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
float3 grad_v2 = make_float3(0.0f, 0.0f, 0.0f);
if ((is_inside) && (norm_normal > kEpsilon)) {
// Forward value in this case is dist = t * t.
// derivative of dist wrt p
grad_p = -2.0f * grad_dist * t * normal;
// derivative of dist wrt normal
const float3 grad_normal = 2.0f * grad_dist * t * (v0p + diff);
// derivative of dist wrt raw_normal
const float3 grad_raw_normal = normalize_backward(raw_normal, grad_normal);
// derivative of dist wrt v2v0 and v1v0
const auto grad_cross = cross_backward(v2v0, v1v0, grad_raw_normal);
const float3 grad_cross_v2v0 = thrust::get<0>(grad_cross);
const float3 grad_cross_v1v0 = thrust::get<1>(grad_cross);
// v0 enters directly through (v0 - p) and, with a negative sign, through
// both edge vectors v2v0 and v1v0.
grad_v0 =
grad_dist * 2.0f * t * normal - (grad_cross_v2v0 + grad_cross_v1v0);
grad_v1 = grad_cross_v1v0;
grad_v2 = grad_cross_v2v0;
} else {
// Closest point is on an edge: recompute the three edge distances and
// back-propagate through the smallest one. Ties are resolved in the order
// e01, e02, e12. If none of the comparisons holds (e.g. NaN distances),
// all gradients stay zero.
const float e01 = PointLine3DistanceForward(p, v0, v1);
const float e02 = PointLine3DistanceForward(p, v0, v2);
const float e12 = PointLine3DistanceForward(p, v1, v2);
if ((e01 <= e02) && (e01 <= e12)) {
// e01 is smallest
const auto grads = PointLine3DistanceBackward(p, v0, v1, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v0 = thrust::get<1>(grads);
grad_v1 = thrust::get<2>(grads);
} else if ((e02 <= e01) && (e02 <= e12)) {
// e02 is smallest
const auto grads = PointLine3DistanceBackward(p, v0, v2, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v0 = thrust::get<1>(grads);
grad_v2 = thrust::get<2>(grads);
} else if ((e12 <= e01) && (e12 <= e02)) {
// e12 is smallest
const auto grads = PointLine3DistanceBackward(p, v1, v2, grad_dist);
grad_p = thrust::get<0>(grads);
grad_v1 = thrust::get<1>(grads);
grad_v2 = thrust::get<2>(grads);
}
}
return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}

View File

@@ -0,0 +1,398 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <ATen/ATen.h>
#include <algorithm>
#include <type_traits>
#include "vec2.h"
#include "vec3.h"
// Set epsilon for preventing floating point errors and division by 0.
// NOTE: the literal has no suffix, so `auto` deduces `double`; expressions
// such as `T + kEpsilon` are evaluated in double and converted back to T.
const auto kEpsilon = 1e-8;
// Determines whether a point p is on the right side of a 2D line segment
// given by the end points v0, v1.
//
// Args:
// p: vec2 Coordinates of a point.
// v0, v1: vec2 Coordinates of the end points of the edge.
//
// Returns:
// area: The signed area of the parallelogram given by the vectors
// A = p - v0
// B = v1 - v0
//
// v1 ________
// /\ /
// A / \ /
// / \ /
// v0 /______\/
// B p
//
// The area can also be interpreted as the cross product A x B.
// If the sign of the area is positive, the point p is on the
// right side of the edge. Negative area indicates the point is on
// the left side of the edge. i.e. for an edge v1 - v0:
//
// v1
// /
// /
// - / +
// /
// /
// v0
//
template <typename T>
T EdgeFunctionForward(const vec2<T>& p, const vec2<T>& v0, const vec2<T>& v1) {
  // Components of A = p - v0 and B = v1 - v0.
  const T ax = p.x - v0.x;
  const T ay = p.y - v0.y;
  const T bx = v1.x - v0.x;
  const T by = v1.y - v0.y;
  // 2D cross product A x B, i.e. the signed parallelogram area.
  return ax * by - ay * bx;
}
// Backward pass of EdgeFunctionForward: scales the partial derivatives of the
// edge function with respect to each input point by the upstream gradient.
//
// Args:
//     p: vec2 Coordinates of a point.
//     v0, v1: vec2 Coordinates of the end points of the edge.
//     grad_edge: Upstream gradient for output from edge function.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (vec2<T> d_edge_dp, vec2<T> d_edge_dv0, vec2<T> d_edge_dv1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> EdgeFunctionBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T grad_edge) {
  // Partial derivatives of
  //   edge = (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x)
  // with respect to p, v0 and v1.
  const vec2<T> local_dp(v1.y - v0.y, v0.x - v1.x);
  const vec2<T> local_dv0(p.y - v1.y, v1.x - p.x);
  const vec2<T> local_dv1(v0.y - p.y, p.x - v0.x);

  // Chain rule: multiply each local derivative by the upstream gradient.
  const vec2<T> grad_p = grad_edge * local_dp;
  const vec2<T> grad_v0 = grad_edge * local_dv0;
  const vec2<T> grad_v1 = grad_edge * local_dv1;
  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: barycentric coordinates of a 2D point relative to a triangle,
// computed as ratios of edge functions to the (signed) triangle area.
// Ref:
// https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-rendering-a-triangle/barycentric-coordinates
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the triangle vertices.
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates. The values are not clamped,
//         so components may fall outside [0, 1].
//
template <typename T>
vec3<T> BarycentricCoordinatesForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // kEpsilon is added so that the divisions below do not divide by exactly 0.
  const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Each weight uses the edge opposite to its vertex.
  const T e0 = EdgeFunctionForward(p, v1, v2);
  const T e1 = EdgeFunctionForward(p, v2, v0);
  const T e2 = EdgeFunctionForward(p, v0, v1);
  return vec3<T>(e0 / area, e1 / area, e2 / area);
}
// The backward pass for computing the barycentric coordinates of a point
// relative to a triangle (see BarycentricCoordinatesForward).
//
// Each coordinate is w_i = e_i / area, where e_i is an edge function of p and
// two vertices and area is the edge function of the three vertices (plus
// kEpsilon). The gradient of every w_i is therefore the sum of a contribution
// through its numerator e_i and one through the shared denominator area.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: (x, y) coordinates of the triangle vertices.
//     grad_bary_upstream: vec3<T> Upstream gradient for each of the
//                         barycentric coordinates [grad_w0, grad_w1, grad_w2].
//
// Returns
//    tuple of gradients for the point and each of the triangle vertices:
//        (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>> BarycentricCoordsBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2,
    const vec3<T>& grad_bary_upstream) {
  // Recompute the forward quantities.
  const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
  // Plain multiplication instead of pow(): cheaper and exactly rounded.
  const T area2 = area * area;
  const T area_inv = 1.0f / area;
  const T e0 = EdgeFunctionForward(p, v1, v2);
  const T e1 = EdgeFunctionForward(p, v2, v0);
  const T e2 = EdgeFunctionForward(p, v0, v1);

  const T grad_w0 = grad_bary_upstream.x;
  const T grad_w1 = grad_bary_upstream.y;
  const T grad_w2 = grad_bary_upstream.z;

  // Calculate component of the gradient from each of w0, w1 and w2.
  // e.g. for w0:
  // dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
  //               + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
  //
  // EdgeFunctionBackward(a, b, c, g) returns gradients in the order of its
  // point arguments (a, b, c). The index passed to std::get below therefore
  // depends on the argument order of each call:
  //   area = Edge(v2, v0, v1) -> (d/dv2, d/dv0, d/dv1)
  //   e0   = Edge(p,  v1, v2) -> (d/dp,  d/dv1, d/dv2)
  //   e1   = Edge(p,  v2, v0) -> (d/dp,  d/dv2, d/dv0)
  //   e2   = Edge(p,  v0, v1) -> (d/dp,  d/dv0, d/dv1)

  // Contribution of w0 = e0 / area.
  const T dw0_darea = -e0 / (area2);
  const T dw0_e0 = area_inv;
  const T dloss_d_w0area = grad_w0 * dw0_darea;
  const T dloss_e0 = grad_w0 * dw0_e0;
  auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
  auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
  const vec2<T> dw0_p = std::get<0>(de0_dv);
  const vec2<T> dw0_dv0 = std::get<1>(dw0area_dv);
  const vec2<T> dw0_dv1 = std::get<1>(de0_dv) + std::get<2>(dw0area_dv);
  const vec2<T> dw0_dv2 = std::get<2>(de0_dv) + std::get<0>(dw0area_dv);

  // Contribution of w1 = e1 / area.
  const T dw1_darea = -e1 / (area2);
  const T dw1_e1 = area_inv;
  const T dloss_d_w1area = grad_w1 * dw1_darea;
  const T dloss_e1 = grad_w1 * dw1_e1;
  auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
  auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
  const vec2<T> dw1_p = std::get<0>(de1_dv);
  const vec2<T> dw1_dv0 = std::get<2>(de1_dv) + std::get<1>(dw1area_dv);
  const vec2<T> dw1_dv1 = std::get<2>(dw1area_dv);
  const vec2<T> dw1_dv2 = std::get<1>(de1_dv) + std::get<0>(dw1area_dv);

  // Contribution of w2 = e2 / area.
  const T dw2_darea = -e2 / (area2);
  const T dw2_e2 = area_inv;
  const T dloss_d_w2area = grad_w2 * dw2_darea;
  const T dloss_e2 = grad_w2 * dw2_e2;
  auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
  auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
  const vec2<T> dw2_p = std::get<0>(de2_dv);
  const vec2<T> dw2_dv0 = std::get<1>(de2_dv) + std::get<1>(dw2area_dv);
  const vec2<T> dw2_dv1 = std::get<2>(de2_dv) + std::get<2>(dw2area_dv);
  const vec2<T> dw2_dv2 = std::get<0>(dw2area_dv);

  // Sum the three contributions per input point.
  const vec2<T> dbary_p = dw0_p + dw1_p + dw2_p;
  const vec2<T> dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
  const vec2<T> dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
  const vec2<T> dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;

  return std::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
}
// Forward pass of perspective correction for barycentric coordinates: every
// coordinate is weighted by the product of the z-values of the two *other*
// vertices, and the result is renormalized by the sum of the weighted terms.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//
// Returns
//     World-space barycentric coordinates
//
template <typename T>
inline vec3<T> BarycentricPerspectiveCorrectionForward(
    const vec3<T>& bary,
    const T z0,
    const T z1,
    const T z2) {
  // Unnormalized corrected coordinates.
  const T n0 = bary.x * z1 * z2;
  const T n1 = bary.y * z0 * z2;
  const T n2 = bary.z * z0 * z1;
  // Normalizer; note that no epsilon is added here.
  const T total = n0 + n1 + n2;
  return vec3<T>(n0 / total, n1 / total, n2 / total);
}
// Backward pass of BarycentricPerspectiveCorrectionForward.
//
// Args:
//     bary: Screen-space barycentric coordinates for a point
//     z0, z1, z2: Camera-space z-coordinates of the triangle vertices
//     grad_out: Upstream gradient of the loss with respect to the corrected
//               barycentric coordinates.
//
// Returns a tuple of:
//      grad_bary: Downstream gradient of the loss with respect to the
//                 uncorrected barycentric coordinates.
//      grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
//                                 to the z-coordinates of the triangle verts
template <typename T>
inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
    const vec3<T>& bary,
    const T z0,
    const T z1,
    const T z2,
    const vec3<T>& grad_out) {
  // Redo the forward computation: unnormalized terms and their sum.
  const T n0 = bary.x * z1 * z2;
  const T n1 = bary.y * z0 * z2;
  const T n2 = bary.z * z0 * z1;
  const T total = n0 + n1 + n2;

  // Gradient flowing through the shared normalizer `total`.
  const T grad_total_num =
      -n0 * grad_out.x - n1 * grad_out.y - n2 * grad_out.z;
  const T grad_total = grad_total_num / (total * total);

  // Gradient with respect to each unnormalized term: normalizer path plus
  // the direct path through that term's own numerator.
  const T grad_n0 = grad_total + grad_out.x / total;
  const T grad_n1 = grad_total + grad_out.y / total;
  const T grad_n2 = grad_total + grad_out.z / total;

  // n_i = bary_i * (product of the other two z-values).
  const vec3<T> grad_bary(
      grad_n0 * z1 * z2, grad_n1 * z0 * z2, grad_n2 * z0 * z1);

  // Each z_k appears in the two terms that do not belong to vertex k.
  const T grad_z0 = grad_n1 * bary.y * z2 + grad_n2 * bary.z * z1;
  const T grad_z1 = grad_n0 * bary.x * z2 + grad_n2 * bary.z * z0;
  const T grad_z2 = grad_n0 * bary.x * z1 + grad_n1 * bary.y * z0;
  return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
}
// Calculate minimum squared distance between a line segment (v1 - v0) and a
// point p.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//
// Returns:
//     squared distance of the point to the segment.
//
// Consider the line extending the segment - this can be parameterized as:
// v0 + t (v1 - v0).
//
// First find the projection of point p onto the line. It falls where:
// t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2
// where . is the dot product.
//
// The parameter t is clamped to [0, 1] to handle points whose projection
// falls outside the segment (v1 - v0).
//
// Once the projection of the point on the segment is known, the distance from
// p to the projection gives the minimum distance to the segment.
//
// If the segment has (near) zero length, the squared distance from p to v1
// is returned instead.
//
template <typename T>
T PointLineDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1) {
  const vec2<T> v1v0 = v1 - v0;
  const T l2 = dot(v1v0, v1v0);
  if (l2 <= kEpsilon) {
    // Degenerate segment: v0 and v1 (almost) coincide.
    return dot(p - v1, p - v1);
  }

  const T t = dot(v1v0, p - v0) / l2;
  // The clamp bounds are typed as T: std::min/std::max deduce a single
  // template argument, so mixing a double `t` with float literals would not
  // compile for T = double.
  const T lower = 0;
  const T upper = 1;
  const T tt = std::min(std::max(t, lower), upper);
  const vec2<T> p_proj = v0 + tt * v1v0;
  return dot(p - p_proj, p - p_proj);
}
// Backward pass for point to line distance in 2D
// (see PointLineDistanceForward).
//
// The clamped projection parameter tt is treated as a constant when
// differentiating |p - p_proj|^2 with p_proj = (1 - tt) * v0 + tt * v1.
// For a (near) zero-length segment the forward pass returns |p - v1|^2, so
// in that case only p and v1 receive a gradient.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> PointLineDistanceBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const T& grad_dist) {
  // Redo some of the forward pass calculations.
  const vec2<T> v1v0 = v1 - v0;
  const vec2<T> pv0 = p - v0;
  const T t_bot = dot(v1v0, v1v0);

  if (t_bot <= kEpsilon) {
    // Same degenerate test as the forward pass, which then computes
    // dist = dot(p - v1, p - v1). Without this branch t_top / t_bot would
    // produce NaN/inf gradients.
    const vec2<T> grad_p = grad_dist * 2.0f * (p - v1);
    const vec2<T> grad_v0(0.0f, 0.0f);
    const vec2<T> grad_v1 = -2.0f * grad_dist * (p - v1);
    return std::make_tuple(grad_p, grad_v0, grad_v1);
  }

  const T t_top = dot(v1v0, pv0);
  const T t = t_top / t_bot;
  // Clamp bounds are typed as T so that std::min/std::max also deduce for
  // T = double.
  const T lower = 0;
  const T upper = 1;
  const T tt = std::min(std::max(t, lower), upper);
  const vec2<T> p_proj = (1.0f - tt) * v0 + tt * v1;

  const vec2<T> grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
  const vec2<T> grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);
  const vec2<T> grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);

  return std::make_tuple(grad_p, grad_v0, grad_v1);
}
// Forward pass: squared distance from a 2D point to the closest of the three
// edges of a triangle. Note that only the edges are considered, so the result
// is the distance to the triangle's boundary.
// Ref: https://www.randygaul.net/2014/07/23/distance-point-to-line-segment/
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
//     shortest squared distance from the point to an edge of the triangle.
//
template <typename T>
T PointTriangleDistanceForward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2) {
  // Squared distance to each edge; the minimum is returned.
  const T d01 = PointLineDistanceForward(p, v0, v1);
  const T d02 = PointLineDistanceForward(p, v0, v2);
  const T d12 = PointLineDistanceForward(p, v1, v2);
  return std::min({d01, d02, d12});
}
// Backward pass for point triangle distance: the closest edge is determined
// again and the gradient of the point-segment distance for that edge is
// returned. The vertex that is not part of the closest edge gets a zero
// gradient.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for the point and each of the triangle vertices:
//         (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
//
template <typename T>
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>>
PointTriangleDistanceBackward(
    const vec2<T>& p,
    const vec2<T>& v0,
    const vec2<T>& v1,
    const vec2<T>& v2,
    const T& grad_dist) {
  // Squared distance from p to each of the three edges.
  const T d01 = PointLineDistanceForward(p, v0, v1);
  const T d02 = PointLineDistanceForward(p, v0, v2);
  const T d12 = PointLineDistanceForward(p, v1, v2);

  // All gradients start at zero.
  vec2<T> grad_p(0.0f, 0.0f);
  vec2<T> grad_v0(0.0f, 0.0f);
  vec2<T> grad_v1(0.0f, 0.0f);
  vec2<T> grad_v2(0.0f, 0.0f);

  // Ties are resolved in the order e01, e02, e12. If none of the flags is set
  // (e.g. the distances are NaN) every gradient stays zero.
  const bool e01_closest = d01 <= d02 && d01 <= d12;
  const bool e02_closest = d02 <= d01 && d02 <= d12;
  const bool e12_closest = d12 <= d01 && d12 <= d02;

  if (e01_closest) {
    // Closest edge is v1 - v0.
    const auto edge_grads = PointLineDistanceBackward(p, v0, v1, grad_dist);
    grad_p = std::get<0>(edge_grads);
    grad_v0 = std::get<1>(edge_grads);
    grad_v1 = std::get<2>(edge_grads);
  } else if (e02_closest) {
    // Closest edge is v2 - v0.
    const auto edge_grads = PointLineDistanceBackward(p, v0, v2, grad_dist);
    grad_p = std::get<0>(edge_grads);
    grad_v0 = std::get<1>(edge_grads);
    grad_v2 = std::get<2>(edge_grads);
  } else if (e12_closest) {
    // Closest edge is v2 - v1.
    const auto edge_grads = PointLineDistanceBackward(p, v1, v2, grad_dist);
    grad_p = std::get<0>(edge_grads);
    grad_v1 = std::get<1>(edge_grads);
    grad_v2 = std::get<2>(edge_grads);
  }

  return std::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}

View File

@@ -0,0 +1,218 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
// This converts dynamic array lookups into static array lookups, for small
// arrays up to size 32.
//
// Suppose we have a small thread-local array:
//
// float vals[10];
//
// Ideally we should only index this array using static indices:
//
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
//
// If we do so, then the CUDA compiler may be able to place the array into
// registers, which can have a big performance improvement. However if we
// access the array dynamically, then the compiler may force the array into
// local memory, which has the same latency as global memory.
//
// These functions convert dynamic array access into static array access
// using a brute-force lookup table. It can be used like this:
//
// float vals[10];
// int idx = 3;
// float val = 3.14f;
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
//
// The implementation is based on fbcuda/RegisterUtils.cuh:
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
// To avoid depending on the entire library, we just reimplement these two
// functions. The fbcuda implementation is a bit more sophisticated, and uses
// the preprocessor to generate switch statements that go up to N for each
// value of N. We are lazy and just have a giant explicit switch statement.
//
// We might be able to use a template metaprogramming approach similar to
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
// for dispatching to the correct CUDA kernel on the host, while this is
// intended to run on the device. I was concerned that a metaprogramming
// approach for this might lead to extra function calls at runtime if the
// compiler fails to optimize them away, which could be very slow on device.
// However I didn't actually benchmark or test this.
template <typename T, int N>
struct RegisterIndexUtils {
  // The switch statements below only enumerate indices 0..31. For a larger N
  // the indices 32..N-1 would pass the bounds check but match no case, so
  // get() would return T() and set() would silently do nothing.
  static_assert(
      N <= 32,
      "RegisterIndexUtils only supports arrays of up to 32 elements");

  // Read arr[idx] through a switch so that every array access in the
  // generated code uses a constant index.
  //
  // Args:
  //     arr: array of N elements
  //     idx: runtime index into arr
  //
  // Returns:
  //     arr[idx], or a value-initialized T if idx is outside [0, N).
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N)
      return T();
    // Cases with an index >= N are never reached thanks to the check above.
    switch (idx) {
      case 0: return arr[0];
      case 1: return arr[1];
      case 2: return arr[2];
      case 3: return arr[3];
      case 4: return arr[4];
      case 5: return arr[5];
      case 6: return arr[6];
      case 7: return arr[7];
      case 8: return arr[8];
      case 9: return arr[9];
      case 10: return arr[10];
      case 11: return arr[11];
      case 12: return arr[12];
      case 13: return arr[13];
      case 14: return arr[14];
      case 15: return arr[15];
      case 16: return arr[16];
      case 17: return arr[17];
      case 18: return arr[18];
      case 19: return arr[19];
      case 20: return arr[20];
      case 21: return arr[21];
      case 22: return arr[22];
      case 23: return arr[23];
      case 24: return arr[24];
      case 25: return arr[25];
      case 26: return arr[26];
      case 27: return arr[27];
      case 28: return arr[28];
      case 29: return arr[29];
      case 30: return arr[30];
      case 31: return arr[31];
    }
    return T();
  }

  // Write val to arr[idx] through a switch so that every array access in the
  // generated code uses a constant index.
  //
  // Args:
  //     arr: array of N elements
  //     idx: runtime index into arr
  //     val: value to store
  //
  // Does nothing if idx is outside [0, N).
  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N)
      return;
    // Cases with an index >= N are never reached thanks to the check above.
    switch (idx) {
      case 0: arr[0] = val; break;
      case 1: arr[1] = val; break;
      case 2: arr[2] = val; break;
      case 3: arr[3] = val; break;
      case 4: arr[4] = val; break;
      case 5: arr[5] = val; break;
      case 6: arr[6] = val; break;
      case 7: arr[7] = val; break;
      case 8: arr[8] = val; break;
      case 9: arr[9] = val; break;
      case 10: arr[10] = val; break;
      case 11: arr[11] = val; break;
      case 12: arr[12] = val; break;
      case 13: arr[13] = val; break;
      case 14: arr[14] = val; break;
      case 15: arr[15] = val; break;
      case 16: arr[16] = val; break;
      case 17: arr[17] = val; break;
      case 18: arr[18] = val; break;
      case 19: arr[19] = val; break;
      case 20: arr[20] = val; break;
      case 21: arr[21] = val; break;
      case 22: arr[22] = val; break;
      case 23: arr[23] = val; break;
      case 24: arr[24] = val; break;
      case 25: arr[25] = val; break;
      case 26: arr[26] = val; break;
      case 27: arr[27] = val; break;
      case 28: arr[28] = val; break;
      case 29: arr[29] = val; break;
      case 30: arr[30] = val; break;
      case 31: arr[31] = val; break;
    }
  }
};

View File

@@ -0,0 +1,159 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#define MINK_H
#include "index_utils.cuh"
// A data structure to keep track of the smallest K keys seen so far as well
// as their associated values, intended to be used in device code.
// This data structure doesn't allocate any memory; keys and values are stored
// in arrays passed to the constructor.
//
// The implementation is generic; it can be used for any key type that supports
// the < operator, and can be used with any value type.
//
// Example usage:
//
// float keys[K];
// int values[K];
// MinK<float, int> mink(keys, values, K);
// for (...) {
// // Produce some key and value from somewhere
// mink.add(key, value);
// }
// mink.sort();
//
// Now keys and values store the smallest K keys seen so far and the values
// associated to these keys:
//
// for (int k = 0; k < K; ++k) {
// float key_k = keys[k];
// int value_k = values[k];
// }
template <typename key_t, typename value_t>
class MinK {
 public:
  // Constructor.
  //
  // Arguments:
  //   keys: Array in which to store keys
  //   vals: Array in which to store values
  //   K: How many values to keep track of
  //
  // max_key / max_idx are initialized so they never hold indeterminate
  // values; they become meaningful once the first element has been added.
  __device__ MinK(key_t* keys, value_t* vals, int K)
      : keys(keys), vals(vals), K(K), _size(0), max_key(), max_idx(0) {}

  // Try to add a new key and associated value to the data structure. If the key
  // is one of the smallest K seen so far then it will be kept; otherwise it
  // will not be kept.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      // Still filling up: append and keep track of the largest stored key.
      keys[_size] = key;
      vals[_size] = val;
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (K > 0 && key < max_key) {
      // Full (the K > 0 guard keeps a zero-capacity structure from writing
      // to the arrays): overwrite the current largest key, then rescan all K
      // slots to find the new largest one.
      keys[max_idx] = key;
      vals[max_idx] = val;
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = keys[k];
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() const {
    return _size;
  }

  // Sort the items stored in the structure using bubble sort.
  // This takes O(K^2) time.
  // NOTE: max_idx is not updated by the swaps below.
  __device__ __forceinline__ void sort() {
    for (int i = 0; i < _size - 1; ++i) {
      for (int j = 0; j < _size - i - 1; ++j) {
        if (keys[j + 1] < keys[j]) {
          // Swap the key/value pairs at positions j and j + 1.
          key_t key = keys[j];
          value_t val = vals[j];
          keys[j] = keys[j + 1];
          vals[j] = vals[j + 1];
          keys[j + 1] = key;
          vals[j + 1] = val;
        }
      }
    }
  }

 private:
  key_t* keys; // caller-provided storage for the kept keys
  value_t* vals; // caller-provided storage for the associated values
  int K; // capacity
  int _size; // number of elements currently stored
  key_t max_key; // largest key currently stored (valid when _size > 0)
  int max_idx; // position of max_key inside keys
};
// This is a version of MinK that only touches the arrays using static indexing
// via RegisterIndexUtils. If the keys and values are stored in thread-local
// arrays, then this may allow the compiler to place them in registers for
// fast access.
//
// This offers the same add() / size() interface as MinK (with K given as a
// template parameter instead of a constructor argument), but doesn't support
// sorting.
// We found that sorting via RegisterIndexUtils gave very poor performance,
// and suspect it may have prevented the compiler from placing the arrays
// into registers.
template <typename key_t, typename value_t, int K>
class RegisterMinK {
 public:
  // Arguments:
  //   keys: Array of K elements in which to store keys
  //   vals: Array of K elements in which to store values
  //
  // max_key / max_idx are initialized so they never hold indeterminate
  // values; they become meaningful once the first element has been added.
  __device__ RegisterMinK(key_t* keys, value_t* vals)
      : keys(keys), vals(vals), _size(0), max_key(), max_idx(0) {}

  // Try to add a key/value pair; it is kept only if the key is among the K
  // smallest keys seen so far.
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      // Still filling up: append and keep track of the largest stored key.
      RegisterIndexUtils<key_t, K>::set(keys, _size, key);
      RegisterIndexUtils<value_t, K>::set(vals, _size, val);
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (K > 0 && key < max_key) {
      // Full: overwrite the current largest key, then rescan all K slots to
      // find the new largest one.
      RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
      RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  __device__ __forceinline__ int size() const {
    return _size;
  }

 private:
  key_t* keys; // caller-provided storage for the kept keys
  value_t* vals; // caller-provided storage for the associated values
  int _size; // number of elements currently stored
  key_t max_key; // largest key currently stored (valid when _size > 0)
  int max_idx; // position of max_key inside keys
};

View File

@@ -0,0 +1,11 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <torch/extension.h>
// Argument checks for tensors. Each macro forwards a condition and a message
// to AT_ASSERTM; the message starts with the stringified argument expression,
// followed by a space.
#define CHECK_CUDA(x) AT_ASSERTM((x).is_cuda(), #x " must be a CUDA tensor.")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM((x).is_contiguous(), #x " must be contiguous.")
// Runs both checks. Wrapped in do { } while (0) so that the expansion is a
// single statement and both checks stay guarded when the macro is used as the
// body of an unbraced if/else.
#define CHECK_CONTIGUOUS_CUDA(x) \
  do {                           \
    CHECK_CUDA(x);               \
    CHECK_CONTIGUOUS(x);         \
  } while (0)

View File

@@ -0,0 +1,59 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
#include <type_traits>
// A minimal fixed-size 2D vector (x, y) with a few arithmetic operators,
// used for representing 2D coordinates.
// The second template parameter restricts T to float or double.
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec2 {
  using scalar_t = T;
  T x, y;
  vec2(T x_val, T y_val) : x(x_val), y(y_val) {}
};

// Component-wise sum.
template <typename T>
inline vec2<T> operator+(const vec2<T>& a, const vec2<T>& b) {
  const T sum_x = a.x + b.x;
  const T sum_y = a.y + b.y;
  return vec2<T>(sum_x, sum_y);
}

// Component-wise difference.
template <typename T>
inline vec2<T> operator-(const vec2<T>& a, const vec2<T>& b) {
  const T diff_x = a.x - b.x;
  const T diff_y = a.y - b.y;
  return vec2<T>(diff_x, diff_y);
}

// Scalar * vector (the scalar is the left operand).
template <typename T>
inline vec2<T> operator*(const T a, const vec2<T>& b) {
  const T scaled_x = a * b.x;
  const T scaled_y = a * b.y;
  return vec2<T>(scaled_x, scaled_y);
}
// Vector / scalar. A denominator that compares equal to zero is reported
// through AT_ERROR instead of being divided by.
template <typename T>
inline vec2<T> operator/(const vec2<T>& a, const T b) {
  const bool zero_denominator = (b == 0.0);
  if (zero_denominator) {
    AT_ERROR(
        "denominator in vec2 division is 0"); // prevent divide by 0 errors.
  }
  const T quot_x = a.x / b;
  const T quot_y = a.y / b;
  return vec2<T>(quot_x, quot_y);
}
template <typename T>
inline T dot(const vec2<T>& a, const vec2<T>& b) {
return a.x * b.x + a.y * b.y;
}
template <typename T>
inline T norm(const vec2<T>& a, const vec2<T>& b) {
const vec2<T> ba = b - a;
return sqrt(dot(ba, ba));
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec2<T>& v) {
os << "vec2(" << v.x << ", " << v.y << ")";
return os;
}

View File

@@ -0,0 +1,63 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#pragma once
// A minimal fixed-size 3D vector (x, y, z) with a few arithmetic operators,
// used for representing 3D coordinates.
// The second template parameter restricts T to float or double.
// TODO: switch to Eigen if more functionality is needed.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec3 {
  using scalar_t = T;
  T x, y, z;
  vec3(T x_val, T y_val, T z_val) : x(x_val), y(y_val), z(z_val) {}
};

// Component-wise sum.
template <typename T>
inline vec3<T> operator+(const vec3<T>& a, const vec3<T>& b) {
  const T sum_x = a.x + b.x;
  const T sum_y = a.y + b.y;
  const T sum_z = a.z + b.z;
  return vec3<T>(sum_x, sum_y, sum_z);
}

// Component-wise difference.
template <typename T>
inline vec3<T> operator-(const vec3<T>& a, const vec3<T>& b) {
  const T diff_x = a.x - b.x;
  const T diff_y = a.y - b.y;
  const T diff_z = a.z - b.z;
  return vec3<T>(diff_x, diff_y, diff_z);
}
// Divides every component of the 3D vector a by the scalar b.
// A zero divisor is reported through AT_ERROR before any division happens.
// AT_ERROR is defined outside this file chunk; presumably it raises and does
// not return -- confirm against its definition.
template <typename T>
inline vec3<T> operator/(const vec3<T>& a, const T b) {
  const bool zero_denominator = (b == static_cast<T>(0));
  if (zero_denominator) {
    // prevent divide by 0 errors.
    AT_ERROR("denominator in vec3 division is 0");
  }
  const T quot_x = a.x / b;
  const T quot_y = a.y / b;
  const T quot_z = a.z / b;
  return vec3<T>(quot_x, quot_y, quot_z);
}
// Scales every component of the 3D vector b by the scalar a.
// Only the scalar-on-the-left form (a * b) is provided by this overload.
template <typename T>
inline vec3<T> operator*(const T a, const vec3<T>& b) {
  const T scaled_x = a * b.x;
  const T scaled_y = a * b.y;
  const T scaled_z = a * b.z;
  return vec3<T>(scaled_x, scaled_y, scaled_z);
}
// Element-wise (Hadamard) product of two 3D vectors. This is neither the dot
// product nor the cross product; see dot() and cross() for those.
template <typename T>
inline vec3<T> operator*(const vec3<T>& a, const vec3<T>& b) {
  const T prod_x = a.x * b.x;
  const T prod_y = a.y * b.y;
  const T prod_z = a.z * b.z;
  return vec3<T>(prod_x, prod_y, prod_z);
}
// Dot (inner) product of two 3D vectors.
template <typename T>
inline T dot(const vec3<T>& a, const vec3<T>& b) {
  const T xx = a.x * b.x;
  const T yy = a.y * b.y;
  const T zz = a.z * b.z;
  // Same left-to-right summation order as a plain x + y + z expression.
  return (xx + yy) + zz;
}
// Cross product a x b of two 3D vectors.
template <typename T>
inline vec3<T> cross(const vec3<T>& a, const vec3<T>& b) {
  const T cx = a.y * b.z - a.z * b.y;
  const T cy = a.z * b.x - a.x * b.z;
  const T cz = a.x * b.y - a.y * b.x;
  return vec3<T>(cx, cy, cz);
}
// Streams the vector as "vec3(x, y, z)" and returns the stream for chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const vec3<T>& v) {
  return os << "vec3(" << v.x << ", " << v.y << ", " << v.z << ")";
}

View File

@@ -0,0 +1,44 @@
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#include <float.h>
#include <math.h>
#include <cstdio>
// Helper WarpReduce used in .cu files.
//
// Last stage of a min-reduction over a pair of parallel arrays: a distance
// array and the index array that travels with it. For each stride
// s = 32, 16, 8, 4, 2, 1 the calling thread compares min_dists[tid] with
// min_dists[tid + s] and, when the latter is strictly smaller, copies both the
// index and the distance from slot tid + s into slot tid. Ties keep the entry
// already stored at tid.
//
// Args:
//   min_dists: distances being reduced; read and overwritten in place.
//   min_idxs:  indices paired with min_dists; read and overwritten in place.
//   tid:       slot owned by the calling thread.
//
// Both arrays are read up to element tid + 32, so they must hold at least
// tid + 33 elements.
//
// NOTE(review): presumably called by the first 32 threads of a block
// (tid < 32) on shared-memory scratch buffers -- confirm at the call sites.
//
// A __syncwarp() follows every step. `volatile` alone only forces the memory
// accesses to be issued; it does not keep the lanes of a warp in lockstep on
// GPUs with independent thread scheduling (Volta and newer), so without the
// barrier a lane could start the next stride before its neighbours finished
// the current one. __syncwarp() uses the default full-warp mask, so all
// non-exited lanes of the warp must enter this function together.
template <typename scalar_t>
__device__ void WarpReduce(
    volatile scalar_t* min_dists,
    volatile int64_t* min_idxs,
    const size_t tid) {
  // Strides 32, 16, 8, 4, 2, 1.
#pragma unroll
  for (size_t s = 32; s > 0; s >>= 1) {
    if (min_dists[tid] > min_dists[tid + s]) {
      // Move the index first, then the distance, so the pair stays matched.
      min_idxs[tid] = min_idxs[tid + s];
      min_dists[tid] = min_dists[tid + s];
    }
    // Make this step's writes visible to the whole warp before the next one.
    __syncwarp();
  }
}