mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-04-10 06:25:59 +08:00
point mesh distances
Summary: Implementation of point to mesh distances. The current diff contains two types: (a) Point to Edge (b) Point to Face ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_EDGE_4_100_300_5000_cuda:0 2745 3138 183 POINT_MESH_EDGE_4_100_300_10000_cuda:0 4408 4499 114 POINT_MESH_EDGE_4_100_3000_5000_cuda:0 4978 5070 101 POINT_MESH_EDGE_4_100_3000_10000_cuda:0 9076 9187 56 POINT_MESH_EDGE_4_1000_300_5000_cuda:0 1411 1487 355 POINT_MESH_EDGE_4_1000_300_10000_cuda:0 4829 5030 104 POINT_MESH_EDGE_4_1000_3000_5000_cuda:0 7539 7620 67 POINT_MESH_EDGE_4_1000_3000_10000_cuda:0 12088 12272 42 POINT_MESH_EDGE_8_100_300_5000_cuda:0 3106 3222 161 POINT_MESH_EDGE_8_100_300_10000_cuda:0 8561 8648 59 POINT_MESH_EDGE_8_100_3000_5000_cuda:0 6932 7021 73 POINT_MESH_EDGE_8_100_3000_10000_cuda:0 24032 24176 21 POINT_MESH_EDGE_8_1000_300_5000_cuda:0 5272 5399 95 POINT_MESH_EDGE_8_1000_300_10000_cuda:0 11348 11430 45 POINT_MESH_EDGE_8_1000_3000_5000_cuda:0 17478 17683 29 POINT_MESH_EDGE_8_1000_3000_10000_cuda:0 25961 26236 20 POINT_MESH_EDGE_16_100_300_5000_cuda:0 8244 8323 61 POINT_MESH_EDGE_16_100_300_10000_cuda:0 18018 18071 28 POINT_MESH_EDGE_16_100_3000_5000_cuda:0 19428 19544 26 POINT_MESH_EDGE_16_100_3000_10000_cuda:0 44967 45135 12 POINT_MESH_EDGE_16_1000_300_5000_cuda:0 7825 7937 64 POINT_MESH_EDGE_16_1000_300_10000_cuda:0 18504 18571 28 POINT_MESH_EDGE_16_1000_3000_5000_cuda:0 65805 66132 8 POINT_MESH_EDGE_16_1000_3000_10000_cuda:0 90885 91089 6 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- POINT_MESH_FACE_4_100_300_5000_cuda:0 1561 1685 321 POINT_MESH_FACE_4_100_300_10000_cuda:0 2818 2954 178 POINT_MESH_FACE_4_100_3000_5000_cuda:0 15893 16018 32 POINT_MESH_FACE_4_100_3000_10000_cuda:0 16350 16439 31 
POINT_MESH_FACE_4_1000_300_5000_cuda:0 3179 3278 158 POINT_MESH_FACE_4_1000_300_10000_cuda:0 2353 2436 213 POINT_MESH_FACE_4_1000_3000_5000_cuda:0 16262 16336 31 POINT_MESH_FACE_4_1000_3000_10000_cuda:0 9334 9448 54 POINT_MESH_FACE_8_100_300_5000_cuda:0 4377 4493 115 POINT_MESH_FACE_8_100_300_10000_cuda:0 9728 9822 52 POINT_MESH_FACE_8_100_3000_5000_cuda:0 26428 26544 19 POINT_MESH_FACE_8_100_3000_10000_cuda:0 42238 43031 12 POINT_MESH_FACE_8_1000_300_5000_cuda:0 3891 3982 129 POINT_MESH_FACE_8_1000_300_10000_cuda:0 5363 5429 94 POINT_MESH_FACE_8_1000_3000_5000_cuda:0 20998 21084 24 POINT_MESH_FACE_8_1000_3000_10000_cuda:0 39711 39897 13 POINT_MESH_FACE_16_100_300_5000_cuda:0 5955 6001 84 POINT_MESH_FACE_16_100_300_10000_cuda:0 12082 12144 42 POINT_MESH_FACE_16_100_3000_5000_cuda:0 44996 45176 12 POINT_MESH_FACE_16_100_3000_10000_cuda:0 73042 73197 7 POINT_MESH_FACE_16_1000_300_5000_cuda:0 8292 8374 61 POINT_MESH_FACE_16_1000_300_10000_cuda:0 19442 19506 26 POINT_MESH_FACE_16_1000_3000_5000_cuda:0 36059 36194 14 POINT_MESH_FACE_16_1000_3000_10000_cuda:0 64644 64822 8 -------------------------------------------------------------------------------- ``` Reviewed By: jcjohnson Differential Revision: D20590462 fbshipit-source-id: 42a39837b514a546ac9471bfaff60eefe7fae829
This commit is contained in:
committed by
Facebook GitHub Bot
parent
474c8b456a
commit
487d4d6607
343
pytorch3d/csrc/utils/dispatch.cuh
Normal file
343
pytorch3d/csrc/utils/dispatch.cuh
Normal file
@@ -0,0 +1,343 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
//
|
||||
// This file provides utilities for dispatching to specialized versions of
|
||||
// functions. This is especially useful for CUDA kernels, since specializing
|
||||
// them to particular input sizes can often allow the compiler to unroll loops
|
||||
// and place arrays into registers, which can give huge performance speedups.
|
||||
//
|
||||
// As an example, suppose we have the following function which is specialized
|
||||
// based on a compile-time int64_t value:
|
||||
//
|
||||
// template<typename T, int64_t x>
|
||||
// struct SquareOffset {
|
||||
// static void run(T y) {
|
||||
// T val = x * x + y;
|
||||
// std::cout << val << std::endl;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// This function takes one compile-time argument x, and one run-time argument y.
|
||||
// We might want to compile specialized versions of this for x=0, x=1, etc and
|
||||
// then dispatch to the correct one based on the runtime value of x.
|
||||
// One simple way to achieve this is with a lookup table:
|
||||
//
|
||||
// template<typename T>
|
||||
// void DispatchSquareOffset(const int64_t x, T y) {
|
||||
// if (x == 0) {
|
||||
// SquareOffset<T, 0>::run(y);
|
||||
// } else if (x == 1) {
|
||||
// SquareOffset<T, 1>::run(y);
|
||||
// } else if (x == 2) {
|
||||
// SquareOffset<T, 2>::run(y);
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// This function takes both x and y as run-time arguments, and dispatches to
|
||||
// different specialized versions of SquareOffset based on the run-time value
|
||||
// of x. This works, but it's tedious and error-prone. If we want to change the
|
||||
// set of x values for which we provide compile-time specializations, then we
|
||||
// will need to do a lot of tedious editing of the dispatch function. Also, if we
|
||||
// want to provide compile-time specializations for another function other than
|
||||
// SquareOffset, we will need to duplicate the entire lookup table.
|
||||
//
|
||||
// To solve these problems, we can use the DispatchKernel1D function provided by
|
||||
// this file instead:
|
||||
//
|
||||
// template<typename T>
|
||||
// void DispatchSquareOffset(const int64_t x, T y) {
|
||||
// constexpr int64_t xmin = 0;
|
||||
// constexpr int64_t xmax = 2;
|
||||
// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
|
||||
// }
|
||||
//
|
||||
// DispatchKernel1D uses template metaprogramming to compile specialized
|
||||
// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
|
||||
// then dispatches to the correct one based on the run-time value of x. If we
|
||||
// want to change the range of x values for which SquareOffset is specialized
|
||||
// at compile-time, then all we have to do is change the values of the
|
||||
// compile-time constants xmin and xmax.
|
||||
//
|
||||
// This file also allows us to similarly dispatch functions that depend on two
|
||||
// compile-time int64_t values, using the DispatchKernel2D function like this:
|
||||
//
|
||||
// template<typename T, int64_t x, int64_t y>
|
||||
// struct Sum {
|
||||
// static void run(T z, T w) {
|
||||
// T val = x + y + z + w;
|
||||
// std::cout << val << std::endl;
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// template<typename T>
|
||||
// void DispatchSum(const int64_t x, const int64_t y, T z, T w) {
|
||||
// constexpr int64_t xmin = 1;
|
||||
// constexpr int64_t xmax = 3;
|
||||
// constexpr int64_t ymin = 2;
|
||||
// constexpr int64_t ymax = 5;
|
||||
// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
|
||||
// }
|
||||
//
|
||||
// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
|
||||
// compile specialized versions of sum for all values of (x, y) with
|
||||
// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
|
||||
// specialized version based on the runtime values of x and y.
|
||||
|
||||
// Define some helper structs in an anonymous namespace.
|
||||
namespace {
|
||||
|
||||
// 1D dispatch: primary template (the recursive step).
//
// Kernel is the class template being dispatched to. It must accept a typename
// and an int64_t as template arguments and provide a static void run(...)
// that accepts the forwarded run-time arguments.
//
// curN is a compile-time counter. Each step compares it against the run-time
// value N: on a match the kernel specialized for curN is run; if curN is still
// smaller than N, the helper instantiated with curN + 1 takes over. The
// partial specialization for curN == maxN bounds the template recursion.
template <
    template <typename, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    typename... Args>
struct DispatchKernelHelper1D {
  static void run(const int64_t N, Args... args) {
    if (curN == N) {
      // Compile-time and run-time values agree: call the specialized kernel.
      Kernel<T, curN>::run(args...);
      return;
    }
    if (curN < N) {
      // Not there yet: continue with the next candidate value.
      using NextHelper =
          DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>;
      NextHelper::run(N, args...);
      return;
    }
    // curN > N: nothing is dispatched. We shouldn't get here -- throw an error?
  }
};
|
||||
|
||||
// 1D dispatch: Specialization when curN == maxN
|
||||
// We need this base case to avoid infinite template recursion.
|
||||
template <
|
||||
template <typename, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
typename... Args>
|
||||
struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
|
||||
static void run(const int64_t N, Args... args) {
|
||||
if (N == maxN) {
|
||||
Kernel<T, maxN>::run(args...);
|
||||
}
|
||||
// We shouldn't get here -- throw an error?
|
||||
}
|
||||
};
|
||||
|
||||
// 2D dispatch: primary template (the recursive step).
//
// Works like the 1D helper, but with two compile-time counters curN and curM
// that are advanced via template recursion until they equal the run-time
// values N and M; the kernel specialized for (curN, curM) is then run.
template <
    template <typename, int64_t, int64_t> class Kernel,
    typename T,
    int64_t minN,
    int64_t maxN,
    int64_t curN,
    int64_t minM,
    int64_t maxM,
    int64_t curM,
    typename... Args>
struct DispatchKernelHelper2D {
  static void run(const int64_t N, const int64_t M, Args... args) {
    const bool n_behind = curN < N;
    const bool m_behind = curM < M;

    if (curN == N && curM == M) {
      Kernel<T, curN, curM>::run(args...);
    } else if (n_behind && m_behind) {
      // Advance both counters at once. Advancing one at a time would also
      // work, but this cuts down on the number of recursive calls.
      using StepBoth = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM + 1,
          Args...>;
      StepBoth::run(N, M, args...);
    } else if (n_behind) {
      // Only curN still has to catch up.
      using StepN = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN + 1,
          minM,
          maxM,
          curM,
          Args...>;
      StepN::run(N, M, args...);
    } else if (m_behind) {
      // Only curM still has to catch up.
      using StepM = DispatchKernelHelper2D<
          Kernel,
          T,
          minN,
          maxN,
          curN,
          minM,
          maxM,
          curM + 1,
          Args...>;
      StepM::run(N, M, args...);
    }
    // Any other combination dispatches nothing.
  }
};
|
||||
|
||||
// 2D dispatch, specialization for curN == maxN
|
||||
template <
|
||||
template <typename, int64_t, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
int64_t minM,
|
||||
int64_t maxM,
|
||||
int64_t curM,
|
||||
typename... Args>
|
||||
struct DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
maxN,
|
||||
minM,
|
||||
maxM,
|
||||
curM,
|
||||
Args...> {
|
||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||
if (maxN == N && curM == M) {
|
||||
Kernel<T, maxN, curM>::run(args...);
|
||||
} else if (curM < maxM) {
|
||||
DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
maxN,
|
||||
minM,
|
||||
maxM,
|
||||
curM + 1,
|
||||
Args...>::run(N, M, args...);
|
||||
}
|
||||
// We should not get here -- throw an error?
|
||||
}
|
||||
};
|
||||
|
||||
// 2D dispatch, specialization for curM == maxM
|
||||
template <
|
||||
template <typename, int64_t, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
int64_t curN,
|
||||
int64_t minM,
|
||||
int64_t maxM,
|
||||
typename... Args>
|
||||
struct DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
curN,
|
||||
minM,
|
||||
maxM,
|
||||
maxM,
|
||||
Args...> {
|
||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||
if (curN == N && maxM == M) {
|
||||
Kernel<T, curN, maxM>::run(args...);
|
||||
} else if (curN < maxN) {
|
||||
DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
curN + 1,
|
||||
minM,
|
||||
maxM,
|
||||
maxM,
|
||||
Args...>::run(N, M, args...);
|
||||
}
|
||||
// We should not get here -- throw an error?
|
||||
}
|
||||
};
|
||||
|
||||
// 2D dispatch, specialization for curN == maxN, curM == maxM
|
||||
template <
|
||||
template <typename, int64_t, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
int64_t minM,
|
||||
int64_t maxM,
|
||||
typename... Args>
|
||||
struct DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
maxN,
|
||||
minM,
|
||||
maxM,
|
||||
maxM,
|
||||
Args...> {
|
||||
static void run(const int64_t N, const int64_t M, Args... args) {
|
||||
if (maxN == N && maxM == M) {
|
||||
Kernel<T, maxN, maxM>::run(args...);
|
||||
}
|
||||
// We should not get here -- throw an error?
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// This is the function we expect users to call to dispatch to 1D functions
|
||||
template <
|
||||
template <typename, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
typename... Args>
|
||||
void DispatchKernel1D(const int64_t N, Args... args) {
|
||||
if (minN <= N && N <= maxN) {
|
||||
// Kick off the template recursion by calling the Helper with curN = minN
|
||||
DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
|
||||
N, args...);
|
||||
}
|
||||
// Maybe throw an error if we tried to dispatch outside the allowed range?
|
||||
}
|
||||
|
||||
// This is the function we expect users to call to dispatch to 2D functions
|
||||
template <
|
||||
template <typename, int64_t, int64_t> class Kernel,
|
||||
typename T,
|
||||
int64_t minN,
|
||||
int64_t maxN,
|
||||
int64_t minM,
|
||||
int64_t maxM,
|
||||
typename... Args>
|
||||
void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
|
||||
if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
|
||||
// Kick off the template recursion by calling the Helper with curN = minN
|
||||
// and curM = minM
|
||||
DispatchKernelHelper2D<
|
||||
Kernel,
|
||||
T,
|
||||
minN,
|
||||
maxN,
|
||||
minN,
|
||||
minM,
|
||||
maxM,
|
||||
minM,
|
||||
Args...>::run(N, M, args...);
|
||||
}
|
||||
// Maybe throw an error if we tried to dispatch outside the specified range?
|
||||
}
|
||||
139
pytorch3d/csrc/utils/float_math.cuh
Normal file
139
pytorch3d/csrc/utils/float_math.cuh
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <thrust/tuple.h>
|
||||
|
||||
// Small constant added to a vector norm before dividing by it (used by
// normalize and normalize_backward in this file).
// With MSVC it is a macro that expands to the float literal 1e-8f. With other
// compilers it is a namespace-scope constant whose type is deduced from the
// literal 1e-8, i.e. double, so expressions mixing it with float are evaluated
// in double before being converted back.
// NOTE(review): the reason for the MSVC-specific macro is not visible here.
|
||||
#ifdef _MSC_VER
|
||||
#define vEpsilon 1e-8f
|
||||
#else
|
||||
const auto vEpsilon = 1e-8;
|
||||
#endif
|
||||
|
||||
// Common functions and operators for float2.
|
||||
|
||||
// Component-wise difference a - b of two float2 values.
__device__ inline float2 operator-(const float2& a, const float2& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  return make_float2(dx, dy);
}
|
||||
|
||||
// Component-wise sum a + b of two float2 values.
__device__ inline float2 operator+(const float2& a, const float2& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  return make_float2(sx, sy);
}
|
||||
|
||||
// Component-wise quotient a / b of two float2 values.
// No check is made for zero components in b.
__device__ inline float2 operator/(const float2& a, const float2& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  return make_float2(qx, qy);
}
|
||||
|
||||
// Divides both components of a by the scalar b.
// No check is made for b == 0.
__device__ inline float2 operator/(const float2& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  return make_float2(qx, qy);
}
|
||||
|
||||
// Component-wise (Hadamard) product of two float2 values.
__device__ inline float2 operator*(const float2& a, const float2& b) {
  const float px = a.x * b.x;
  const float py = a.y * b.y;
  return make_float2(px, py);
}
|
||||
|
||||
// Scales both components of b by the scalar a (scalar on the left).
__device__ inline float2 operator*(const float a, const float2& b) {
  const float px = a * b.x;
  const float py = a * b.y;
  return make_float2(px, py);
}
|
||||
|
||||
// Dot product of two float2 values.
__device__ inline float dot(const float2& a, const float2& b) {
  const float result = a.x * b.x + a.y * b.y;
  return result;
}
|
||||
|
||||
// Backward pass for the dot product.
|
||||
// Args:
|
||||
// a, b: Coordinates of two points.
|
||||
// grad_dot: Upstream gradient for the output.
|
||||
//
|
||||
// Returns:
|
||||
// tuple of gradients for each of the input points:
|
||||
// (float2 grad_a, float2 grad_b)
|
||||
//
|
||||
__device__ inline thrust::tuple<float2, float2>
|
||||
DotBackward(const float2& a, const float2& b, const float& grad_dot) {
|
||||
return thrust::make_tuple(grad_dot * b, grad_dot * a);
|
||||
}
|
||||
|
||||
// Sum of the two components of a.
__device__ inline float sum(const float2& a) {
  const float total = a.x + a.y;
  return total;
}
|
||||
|
||||
// Common functions and operators for float3.
|
||||
|
||||
// Component-wise difference a - b of two float3 values.
__device__ inline float3 operator-(const float3& a, const float3& b) {
  const float dx = a.x - b.x;
  const float dy = a.y - b.y;
  const float dz = a.z - b.z;
  return make_float3(dx, dy, dz);
}
|
||||
|
||||
// Component-wise sum a + b of two float3 values.
__device__ inline float3 operator+(const float3& a, const float3& b) {
  const float sx = a.x + b.x;
  const float sy = a.y + b.y;
  const float sz = a.z + b.z;
  return make_float3(sx, sy, sz);
}
|
||||
|
||||
// Component-wise quotient a / b of two float3 values.
// No check is made for zero components in b.
__device__ inline float3 operator/(const float3& a, const float3& b) {
  const float qx = a.x / b.x;
  const float qy = a.y / b.y;
  const float qz = a.z / b.z;
  return make_float3(qx, qy, qz);
}
|
||||
|
||||
// Divides all three components of a by the scalar b.
// No check is made for b == 0.
__device__ inline float3 operator/(const float3& a, const float b) {
  const float qx = a.x / b;
  const float qy = a.y / b;
  const float qz = a.z / b;
  return make_float3(qx, qy, qz);
}
|
||||
|
||||
// Component-wise (Hadamard) product of two float3 values.
__device__ inline float3 operator*(const float3& a, const float3& b) {
  const float px = a.x * b.x;
  const float py = a.y * b.y;
  const float pz = a.z * b.z;
  return make_float3(px, py, pz);
}
|
||||
|
||||
// Scales all three components of b by the scalar a (scalar on the left).
__device__ inline float3 operator*(const float a, const float3& b) {
  const float px = a * b.x;
  const float py = a * b.y;
  const float pz = a * b.z;
  return make_float3(px, py, pz);
}
|
||||
|
||||
// Dot product of two float3 values.
__device__ inline float dot(const float3& a, const float3& b) {
  const float result = a.x * b.x + a.y * b.y + a.z * b.z;
  return result;
}
|
||||
|
||||
// Sum of the three components of a.
__device__ inline float sum(const float3& a) {
  const float total = a.x + a.y + a.z;
  return total;
}
|
||||
|
||||
// Cross product a x b of two float3 values.
__device__ inline float3 cross(const float3& a, const float3& b) {
  const float cx = a.y * b.z - a.z * b.y;
  const float cy = a.z * b.x - a.x * b.z;
  const float cz = a.x * b.y - a.y * b.x;
  return make_float3(cx, cy, cz);
}
|
||||
|
||||
// Backward pass for the cross product cross(a, b).
//
// Args:
//    a, b: The two operands of the forward cross product.
//    grad_cross: Upstream gradient with respect to the float3 output.
//
// Returns:
//    tuple (float3 grad_a, float3 grad_b) of gradients with respect to the
//    two operands. Component-wise these equal b x grad_cross and
//    grad_cross x a respectively.
//
__device__ inline thrust::tuple<float3, float3>
cross_backward(const float3& a, const float3& b, const float3& grad_cross) {
  const float3& g = grad_cross;

  const float3 grad_a = make_float3(
      -g.y * b.z + g.z * b.y,
      g.x * b.z - g.z * b.x,
      -g.x * b.y + g.y * b.x);

  const float3 grad_b = make_float3(
      g.y * a.z - g.z * a.y,
      -g.x * a.z + g.z * a.x,
      g.x * a.y - g.y * a.x);

  return thrust::make_tuple(grad_a, grad_b);
}
|
||||
|
||||
// Euclidean length of a, computed as sqrt(dot(a, a)).
__device__ inline float norm(const float3& a) {
  const float squared_length = dot(a, a);
  return sqrt(squared_length);
}
|
||||
|
||||
// Scales a by 1 / (norm(a) + vEpsilon).
// vEpsilon is added to the length so that the divisor is not exactly zero
// for a zero vector.
__device__ inline float3 normalize(const float3& a) {
  const float divisor = norm(a) + vEpsilon;
  return a / divisor;
}
|
||||
|
||||
// Backward pass for normalize(a).
//
// Args:
//    a: The vector that was normalized in the forward pass.
//    grad_normz: Upstream gradient with respect to the normalized vector.
//
// Returns:
//    float3 gradient with respect to a. Writing u = a / len with
//    len = norm(a) + vEpsilon, the entries of the Jacobian applied here are
//    (delta_ij - u_i * u_j) / len.
//
__device__ inline float3 normalize_backward(
    const float3& a,
    const float3& grad_normz) {
  // Recompute the forward pass.
  const float len = norm(a) + vEpsilon;
  const float3 u = a / len;
  const float3& g = grad_normz;

  const float grad_ax = g.x * (1.0f - u.x * u.x) / len +
      g.y * (-u.x * u.y) / len + g.z * (-u.x * u.z) / len;
  const float grad_ay = g.x * (-u.x * u.y) / len +
      g.y * (1.0f - u.y * u.y) / len + g.z * (-u.y * u.z) / len;
  const float grad_az = g.x * (-u.x * u.z) / len +
      g.y * (-u.y * u.z) / len + g.z * (1.0f - u.z * u.z) / len;

  return make_float3(grad_ax, grad_ay, grad_az);
}
|
||||
651
pytorch3d/csrc/utils/geometry_utils.cuh
Normal file
651
pytorch3d/csrc/utils/geometry_utils.cuh
Normal file
@@ -0,0 +1,651 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <torch/extension.h>
|
||||
#include <cstdio>
|
||||
#include "float_math.cuh"
|
||||
|
||||
// Small constant used to guard against division by zero and degenerate
// geometry: it is added to the triangle area in the barycentric functions and
// is the threshold for a zero-length segment in PointLineDistanceForward.
// With MSVC it is a macro that expands to the float literal 1e-8f. With other
// compilers it is a namespace-scope constant whose type is deduced from the
// literal 1e-8, i.e. double.
// NOTE(review): the reason for the MSVC-specific macro is not visible here.
|
||||
#ifdef _MSC_VER
|
||||
#define kEpsilon 1e-8f
|
||||
#else
|
||||
const auto kEpsilon = 1e-8;
|
||||
#endif
|
||||
|
||||
// ************************************************************* //
|
||||
// vec2 utils //
|
||||
// ************************************************************* //
|
||||
|
||||
// Determines whether a point p is on the right side of a 2D line segment
|
||||
// given by the end points v0, v1.
|
||||
//
|
||||
// Args:
|
||||
// p: vec2 Coordinates of a point.
|
||||
// v0, v1: vec2 Coordinates of the end points of the edge.
|
||||
//
|
||||
// Returns:
|
||||
// area: The signed area of the parallelogram given by the vectors
|
||||
// A = p - v0
|
||||
// B = v1 - v0
|
||||
//
|
||||
__device__ inline float
|
||||
EdgeFunctionForward(const float2& p, const float2& v0, const float2& v1) {
|
||||
return (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x);
|
||||
}
|
||||
|
||||
// Backward pass for the edge function returning partial dervivatives for each
|
||||
// of the input points.
|
||||
//
|
||||
// Args:
|
||||
// p: vec2 Coordinates of a point.
|
||||
// v0, v1: vec2 Coordinates of the end points of the edge.
|
||||
// grad_edge: Upstream gradient for output from edge function.
|
||||
//
|
||||
// Returns:
|
||||
// tuple of gradients for each of the input points:
|
||||
// (float2 d_edge_dp, float2 d_edge_dv0, float2 d_edge_dv1)
|
||||
//
|
||||
__device__ inline thrust::tuple<float2, float2, float2> EdgeFunctionBackward(
|
||||
const float2& p,
|
||||
const float2& v0,
|
||||
const float2& v1,
|
||||
const float& grad_edge) {
|
||||
const float2 dedge_dp = make_float2(v1.y - v0.y, v0.x - v1.x);
|
||||
const float2 dedge_dv0 = make_float2(p.y - v1.y, v1.x - p.x);
|
||||
const float2 dedge_dv1 = make_float2(v0.y - p.y, p.x - v0.x);
|
||||
return thrust::make_tuple(
|
||||
grad_edge * dedge_dp, grad_edge * dedge_dv0, grad_edge * dedge_dv1);
|
||||
}
|
||||
|
||||
// The forward pass for computing the barycentric coordinates of a point
|
||||
// relative to a triangle.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: Coordinates of the triangle vertices.
|
||||
//
|
||||
// Returns
|
||||
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
|
||||
//
|
||||
__device__ inline float3 BarycentricCoordsForward(
|
||||
const float2& p,
|
||||
const float2& v0,
|
||||
const float2& v1,
|
||||
const float2& v2) {
|
||||
const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
|
||||
const float w0 = EdgeFunctionForward(p, v1, v2) / area;
|
||||
const float w1 = EdgeFunctionForward(p, v2, v0) / area;
|
||||
const float w2 = EdgeFunctionForward(p, v0, v1) / area;
|
||||
return make_float3(w0, w1, w2);
|
||||
}
|
||||
|
||||
// The backward pass for computing the barycentric coordinates of a point
|
||||
// relative to a triangle.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: (x, y) coordinates of the triangle vertices.
|
||||
// grad_bary_upstream: vec3<T> Upstream gradient for each of the
|
||||
// barycentric coordaintes [grad_w0, grad_w1, grad_w2].
|
||||
//
|
||||
// Returns
|
||||
// tuple of gradients for each of the triangle vertices:
|
||||
// (float2 grad_v0, float2 grad_v1, float2 grad_v2)
|
||||
//
|
||||
__device__ inline thrust::tuple<float2, float2, float2, float2>
|
||||
BarycentricCoordsBackward(
|
||||
const float2& p,
|
||||
const float2& v0,
|
||||
const float2& v1,
|
||||
const float2& v2,
|
||||
const float3& grad_bary_upstream) {
|
||||
const float area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
|
||||
const float area2 = pow(area, 2.0f);
|
||||
const float e0 = EdgeFunctionForward(p, v1, v2);
|
||||
const float e1 = EdgeFunctionForward(p, v2, v0);
|
||||
const float e2 = EdgeFunctionForward(p, v0, v1);
|
||||
|
||||
const float grad_w0 = grad_bary_upstream.x;
|
||||
const float grad_w1 = grad_bary_upstream.y;
|
||||
const float grad_w2 = grad_bary_upstream.z;
|
||||
|
||||
// Calculate component of the gradient from each of w0, w1 and w2.
|
||||
// e.g. for w0:
|
||||
// dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
|
||||
// + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
|
||||
const float dw0_darea = -e0 / (area2);
|
||||
const float dw0_e0 = 1 / area;
|
||||
const float dloss_d_w0area = grad_w0 * dw0_darea;
|
||||
const float dloss_e0 = grad_w0 * dw0_e0;
|
||||
auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
|
||||
auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
|
||||
const float2 dw0_p = thrust::get<0>(de0_dv);
|
||||
const float2 dw0_dv0 = thrust::get<1>(dw0area_dv);
|
||||
const float2 dw0_dv1 = thrust::get<1>(de0_dv) + thrust::get<2>(dw0area_dv);
|
||||
const float2 dw0_dv2 = thrust::get<2>(de0_dv) + thrust::get<0>(dw0area_dv);
|
||||
|
||||
const float dw1_darea = -e1 / (area2);
|
||||
const float dw1_e1 = 1 / area;
|
||||
const float dloss_d_w1area = grad_w1 * dw1_darea;
|
||||
const float dloss_e1 = grad_w1 * dw1_e1;
|
||||
auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
|
||||
auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
|
||||
const float2 dw1_p = thrust::get<0>(de1_dv);
|
||||
const float2 dw1_dv0 = thrust::get<2>(de1_dv) + thrust::get<1>(dw1area_dv);
|
||||
const float2 dw1_dv1 = thrust::get<2>(dw1area_dv);
|
||||
const float2 dw1_dv2 = thrust::get<1>(de1_dv) + thrust::get<0>(dw1area_dv);
|
||||
|
||||
const float dw2_darea = -e2 / (area2);
|
||||
const float dw2_e2 = 1 / area;
|
||||
const float dloss_d_w2area = grad_w2 * dw2_darea;
|
||||
const float dloss_e2 = grad_w2 * dw2_e2;
|
||||
auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
|
||||
auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
|
||||
const float2 dw2_p = thrust::get<0>(de2_dv);
|
||||
const float2 dw2_dv0 = thrust::get<1>(de2_dv) + thrust::get<1>(dw2area_dv);
|
||||
const float2 dw2_dv1 = thrust::get<2>(de2_dv) + thrust::get<2>(dw2area_dv);
|
||||
const float2 dw2_dv2 = thrust::get<0>(dw2area_dv);
|
||||
|
||||
const float2 dbary_p = dw0_p + dw1_p + dw2_p;
|
||||
const float2 dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
|
||||
const float2 dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
|
||||
const float2 dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
|
||||
|
||||
return thrust::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
|
||||
}
|
||||
|
||||
// Forward pass for applying perspective correction to barycentric coordinates.
|
||||
//
|
||||
// Args:
|
||||
// bary: Screen-space barycentric coordinates for a point
|
||||
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
|
||||
//
|
||||
// Returns
|
||||
// World-space barycentric coordinates
|
||||
//
|
||||
__device__ inline float3 BarycentricPerspectiveCorrectionForward(
|
||||
const float3& bary,
|
||||
const float z0,
|
||||
const float z1,
|
||||
const float z2) {
|
||||
const float w0_top = bary.x * z1 * z2;
|
||||
const float w1_top = z0 * bary.y * z2;
|
||||
const float w2_top = z0 * z1 * bary.z;
|
||||
const float denom = w0_top + w1_top + w2_top;
|
||||
const float w0 = w0_top / denom;
|
||||
const float w1 = w1_top / denom;
|
||||
const float w2 = w2_top / denom;
|
||||
return make_float3(w0, w1, w2);
|
||||
}
|
||||
|
||||
// Backward pass for applying perspective correction to barycentric coordinates.
|
||||
//
|
||||
// Args:
|
||||
// bary: Screen-space barycentric coordinates for a point
|
||||
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
|
||||
// grad_out: Upstream gradient of the loss with respect to the corrected
|
||||
// barycentric coordinates.
|
||||
//
|
||||
// Returns a tuple of:
|
||||
// grad_bary: Downstream gradient of the loss with respect to the the
|
||||
// uncorrected barycentric coordinates.
|
||||
// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
|
||||
// to the z-coordinates of the triangle verts
|
||||
__device__ inline thrust::tuple<float3, float, float, float>
|
||||
BarycentricPerspectiveCorrectionBackward(
|
||||
const float3& bary,
|
||||
const float z0,
|
||||
const float z1,
|
||||
const float z2,
|
||||
const float3& grad_out) {
|
||||
// Recompute forward pass
|
||||
const float w0_top = bary.x * z1 * z2;
|
||||
const float w1_top = z0 * bary.y * z2;
|
||||
const float w2_top = z0 * z1 * bary.z;
|
||||
const float denom = w0_top + w1_top + w2_top;
|
||||
|
||||
// Now do backward pass
|
||||
const float grad_denom_top =
|
||||
-w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
|
||||
const float grad_denom = grad_denom_top / (denom * denom);
|
||||
const float grad_w0_top = grad_denom + grad_out.x / denom;
|
||||
const float grad_w1_top = grad_denom + grad_out.y / denom;
|
||||
const float grad_w2_top = grad_denom + grad_out.z / denom;
|
||||
const float grad_bary_x = grad_w0_top * z1 * z2;
|
||||
const float grad_bary_y = grad_w1_top * z0 * z2;
|
||||
const float grad_bary_z = grad_w2_top * z0 * z1;
|
||||
const float3 grad_bary = make_float3(grad_bary_x, grad_bary_y, grad_bary_z);
|
||||
const float grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
|
||||
const float grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
|
||||
const float grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
|
||||
return thrust::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
|
||||
}
|
||||
|
||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
|
||||
// point p.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1: Coordinates of the end points of the line segment.
|
||||
//
|
||||
// Returns:
|
||||
// squared distance to the boundary of the triangle.
|
||||
//
|
||||
__device__ inline float
|
||||
PointLineDistanceForward(const float2& p, const float2& a, const float2& b) {
|
||||
const float2 ba = b - a;
|
||||
float l2 = dot(ba, ba);
|
||||
float t = dot(ba, p - a) / l2;
|
||||
if (l2 <= kEpsilon) {
|
||||
return dot(p - b, p - b);
|
||||
}
|
||||
t = __saturatef(t); // clamp to the interval [+0.0, 1.0]
|
||||
const float2 p_proj = a + t * ba;
|
||||
const float2 d = (p_proj - p);
|
||||
return dot(d, d); // squared distance
|
||||
}
|
||||
|
||||
// Backward pass for the squared point to line segment distance in 2D.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1: Coordinates of the end points of the line segment.
//     grad_dist: Upstream gradient for the squared distance.
//
// Returns:
//     tuple of gradients for each of the input points:
//         (float2 grad_p, float2 grad_v0, float2 grad_v1)
//
__device__ inline thrust::tuple<float2, float2, float2>
PointLineDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float& grad_dist) {
  // Redo some of the forward pass calculations.
  const float2 v1v0 = v1 - v0;
  const float2 pv0 = p - v0;
  const float t_bot = dot(v1v0, v1v0);
  const float t_top = dot(v1v0, pv0);
  // Projection parameter clamped to [0, 1].
  const float tt = __saturatef(t_top / t_bot);
  // Closest point on the segment, written as a blend of the two end points.
  const float2 p_proj = (1.0f - tt) * v0 + tt * v1;

  // The gradients below hold tt fixed: the distance is differentiated only
  // through p and through the blend (1 - tt) * v0 + tt * v1.
  const float2 grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
  const float2 grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
  const float2 grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);

  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
|
||||
|
||||
// Forward pass for the squared distance between a 2D point and the edges of
// a triangle.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//
// Returns:
//     the smallest of the squared distances from p to the three edges
//     (v0, v1), (v0, v2) and (v1, v2).
//
__device__ inline float PointTriangleDistanceForward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2) {
  // Squared distance from p to each edge of the triangle.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);
  // The closest edge determines the result.
  return fminf(fminf(d01, d02), d12);
}
|
||||
|
||||
// Backward pass for the 2D point to triangle squared distance.
//
// Args:
//     p: Coordinates of a point.
//     v0, v1, v2: Coordinates of the three triangle vertices.
//     grad_dist: Upstream gradient for the distance.
//
// Returns:
//     tuple of gradients for the point and each of the triangle vertices:
//         (float2 grad_p, float2 grad_v0, float2 grad_v1, float2 grad_v2)
//     Only the point and the two end points of the closest edge receive a
//     gradient; the remaining vertex keeps a zero gradient.
//
__device__ inline thrust::tuple<float2, float2, float2, float2>
PointTriangleDistanceBackward(
    const float2& p,
    const float2& v0,
    const float2& v1,
    const float2& v2,
    const float& grad_dist) {
  // Squared distance from p to each of the three edges.
  const float d01 = PointLineDistanceForward(p, v0, v1);
  const float d02 = PointLineDistanceForward(p, v0, v2);
  const float d12 = PointLineDistanceForward(p, v1, v2);

  // All gradients start at zero.
  float2 grad_p = make_float2(0.0f, 0.0f);
  float2 grad_v0 = make_float2(0.0f, 0.0f);
  float2 grad_v1 = make_float2(0.0f, 0.0f);
  float2 grad_v2 = make_float2(0.0f, 0.0f);

  // Decide which edge is the closest one. Ties are resolved in the order
  // (v0, v1), (v0, v2), (v1, v2) by the if / else-if chain below.
  const bool e01_is_min = d01 <= d02 && d01 <= d12;
  const bool e02_is_min = d02 <= d01 && d02 <= d12;
  const bool e12_is_min = d12 <= d01 && d12 <= d02;

  // Back-propagate through the point to segment distance of that edge only.
  if (e01_is_min) {
    thrust::tie(grad_p, grad_v0, grad_v1) =
        PointLineDistanceBackward(p, v0, v1, grad_dist);
  } else if (e02_is_min) {
    thrust::tie(grad_p, grad_v0, grad_v2) =
        PointLineDistanceBackward(p, v0, v2, grad_dist);
  } else if (e12_is_min) {
    thrust::tie(grad_p, grad_v1, grad_v2) =
        PointLineDistanceBackward(p, v1, v2, grad_dist);
  }

  return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
|
||||
|
||||
// ************************************************************* //
|
||||
// vec3 utils //
|
||||
// ************************************************************* //
|
||||
|
||||
// Barycentric coordinates of a 3D point p with respect to the triangle
// (v0, v1, v2), i.e. weights (w0, w1, w2) such that
//     p = w0 * v0 + w1 * v1 + w2 * v2   and   w0 + w1 + w2 = 1.0
//
// NOTE that this function assumes that p lives on the space spanned
// by (v0, v1, v2).
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns
//     bary: (w0, w1, w2) barycentric coordinates
//
__device__ inline float3 BarycentricCoords3Forward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  // Triangle edges and the query point, all expressed relative to v0.
  const float3 edge1 = v1 - v0;
  const float3 edge2 = v2 - v0;
  const float3 rel_p = p - v0;

  // Pairwise dot products needed to solve the 2x2 linear system for (w1, w2).
  const float d00 = dot(edge1, edge1);
  const float d01 = dot(edge1, edge2);
  const float d11 = dot(edge2, edge2);
  const float d20 = dot(rel_p, edge1);
  const float d21 = dot(rel_p, edge2);

  // kEpsilon keeps the denominator away from zero.
  const float denom = d00 * d11 - d01 * d01 + kEpsilon;
  const float w1 = (d11 * d20 - d01 * d21) / denom;
  const float w2 = (d00 * d21 - d01 * d20) / denom;

  // w0 follows from the constraint that the weights sum to one.
  return make_float3(1.0f - w1 - w2, w1, w2);
}
|
||||
|
||||
// Tests whether the point p lies inside the triangle (v0, v1, v2).
// The point counts as inside when every barycentric coordinate of p
// wrt the triangle is >= 0 and <= 1.
//
// NOTE that this function assumes that p lives on the space spanned
// by (v0, v1, v2).
// TODO(gkioxari) explicitly check whether p is coplanar with (v0, v1, v2)
// and throw an error if check fails
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
//     inside: bool indicating whether p is inside triangle
//
__device__ inline bool IsInsideTriangle(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  const float3 w = BarycentricCoords3Forward(p, v0, v1, v2);
  // Each comparison is false for NaN, so a NaN coordinate means "outside".
  return (0.0f <= w.x && w.x <= 1.0f) && (0.0f <= w.y && w.y <= 1.0f) &&
      (0.0f <= w.z && w.z <= 1.0f);
}
|
||||
|
||||
// Minimum squared Euclidean distance between the point p and the segment
// with end points (v0, v1).
// The segment is parametrized as x(t) = v0 + t * (v1 - v0) and we look for
// the t in [0, 1] which minimizes (x(t) - p) ^ 2.
// Note that p does not need to live in the space spanned by (v0, v1)
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1: vec3 coordinates of start and end of segment
//
// Returns:
//     dist: the minimum squared distance of p from segment (v0, v1)
//

__device__ inline float
PointLine3DistanceForward(const float3& p, const float3& v0, const float3& v1) {
  const float3 v1v0 = v1 - v0;
  const float3 pv0 = p - v0;
  const float seg_len2 = dot(v1v0, v1v0);
  const float proj = dot(pv0, v1v0);

  float tt;
  if (seg_len2 < kEpsilon) {
    // Segment length is (close to) zero, i.e. v0 == v1: use t = 0.
    tt = 0.0f;
  } else {
    tt = proj / seg_len2;
  }
  tt = __saturatef(tt); // clamps to [0, 1]

  // Closest point on the segment and its offset from p.
  const float3 closest = v0 + tt * v1v0;
  const float3 offset = p - closest;
  return dot(offset, offset);
}
|
||||
|
||||
// Backward function of the minimum squared Euclidean distance between the point
// p and the line segment (v0, v1).
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1: vec3 coordinates of start and end of segment
//     grad_dist: Float of the gradient wrt dist
//
// Returns:
//     tuple of gradients for the point and line segment (v0, v1):
//         (float3 grad_p, float3 grad_v0, float3 grad_v1)

__device__ inline thrust::tuple<float3, float3, float3>
PointLine3DistanceBackward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float& grad_dist) {
  const float3 v1v0 = v1 - v0;
  const float3 pv0 = p - v0;
  const float t_bot = dot(v1v0, v1v0);
  const float t_top = dot(v1v0, pv0);
  const float3 no_grad = make_float3(0.0f, 0.0f, 0.0f);

  if (t_bot < kEpsilon) {
    // if t_bot small, then v0 == v1,
    // and dist = 0.5 * dot(pv0, pv0) + 0.5 * dot(pv1, pv1)
    const float3 grad_p = grad_dist * 2.0f * pv0;
    const float3 grad_v = -0.5f * grad_p;
    return thrust::make_tuple(grad_p, grad_v, grad_v);
  }

  // Unclamped projection parameter of p onto the line through v0 and v1.
  const float tt = t_top / t_bot;

  if (tt < 0.0f) {
    // Projection parameter below 0: v1 receives no gradient.
    const float3 grad_p = grad_dist * 2.0f * pv0;
    const float3 grad_v0 = -1.0f * grad_p;
    return thrust::make_tuple(grad_p, grad_v0, no_grad);
  }

  if (tt > 1.0f) {
    // Projection parameter above 1: v0 receives no gradient.
    const float3 grad_p = grad_dist * 2.0f * (p - v1);
    const float3 grad_v1 = -1.0f * grad_p;
    return thrust::make_tuple(grad_p, no_grad, grad_v1);
  }

  // Remaining case: differentiate through tt as well.
  const float3 p_proj = v0 + tt * v1v0;
  const float3 diff = p - p_proj;
  const float3 grad_base = grad_dist * 2.0f * diff;
  // Component of the base gradient along the segment direction.
  const float along = dot(grad_base, v1v0);

  const float3 grad_p = grad_base - along * v1v0 / t_bot;
  // Derivatives of tt wrt v0 and v1.
  const float3 dtt_v0 = (-1.0f * v1v0 - pv0 + 2.0f * tt * v1v0) / t_bot;
  const float3 dtt_v1 = (pv0 - 2.0f * tt * v1v0) / t_bot;
  const float3 grad_v0 = (-1.0f + tt) * grad_base - along * dtt_v0;
  const float3 grad_v1 = -along * dtt_v1 - tt * grad_base;

  return thrust::make_tuple(grad_p, grad_v0, grad_v1);
}
|
||||
|
||||
// Squared distance of a point p from a triangle (v0, v1, v2).
// Let p0 be the projection of p on the plane spanned by (v0, v1, v2).
// If p0 is inside the triangle (and the triangle normal is not degenerate),
// the result is the squared distance of p to p0. Otherwise the result is the
// smallest squared distance of p from the line segments (v0, v1), (v0, v2)
// and (v1, v2).
//
// Args:
//     p: vec3 coordinates of a point
//     v0, v1, v2: vec3 coordinates of the triangle vertices
//
// Returns:
//     dist: Float of the squared distance
//

__device__ inline float PointTriangle3DistanceForward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2) {
  float3 raw_normal = cross(v2 - v0, v1 - v0);
  const float norm_normal = norm(raw_normal);
  const float3 normal = normalize(raw_normal);

  // p0 is the projection of p on the plane spanned by (v0, v1, v2)
  // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
  const float t = dot(v0 - p, normal);
  const float3 p0 = p + t * normal;

  const bool is_inside = IsInsideTriangle(p0, v0, v1, v2);
  if (is_inside && (norm_normal > kEpsilon)) {
    // The projection falls inside the triangle, so the distance is
    // norm(p0 - p)^2 = t^2.
    return t * t;
  }

  // Otherwise fall back to the closest of the three edges.
  const float e01 = PointLine3DistanceForward(p, v0, v1);
  const float e02 = PointLine3DistanceForward(p, v0, v2);
  const float e12 = PointLine3DistanceForward(p, v1, v2);

  float dist = e01;
  if (e01 > e02) {
    dist = e02;
  }
  if (dist > e12) {
    dist = e12;
  }
  return dist;
}
|
||||
|
||||
// Backward pass of the squared distance of a point to the triangle
// (v0, v1, v2).
//
// Args:
//     p: xyz coordinates of a point
//     v0, v1, v2: xyz coordinates of the triangle vertices
//     grad_dist: Float of the gradient wrt dist
//
// Returns:
//     tuple of gradients for the point and triangle:
//         (float3 grad_p, float3 grad_v0, float3 grad_v1, float3 grad_v2)
//

__device__ inline thrust::tuple<float3, float3, float3, float3>
PointTriangle3DistanceBackward(
    const float3& p,
    const float3& v0,
    const float3& v1,
    const float3& v2,
    const float& grad_dist) {
  const float3 v2v0 = v2 - v0;
  const float3 v1v0 = v1 - v0;
  const float3 v0p = v0 - p;
  float3 raw_normal = cross(v2v0, v1v0);
  const float norm_normal = norm(raw_normal);
  const float3 normal = normalize(raw_normal);

  // p0 is the projection of p on the plane spanned by (v0, v1, v2)
  // i.e. p0 = p + t * normal, s.t. (p0 - v0) is orthogonal to normal
  const float t = dot(v0 - p, normal);
  const float3 p0 = p + t * normal;
  const float3 diff = t * normal;

  const bool is_inside = IsInsideTriangle(p0, v0, v1, v2);

  if (is_inside && (norm_normal > kEpsilon)) {
    // Face case: dist = t^2, differentiate through t and the normal.

    // derivative of dist wrt p
    const float3 grad_p = -2.0f * grad_dist * t * normal;
    // derivative of dist wrt normal
    const float3 grad_normal = 2.0f * grad_dist * t * (v0p + diff);
    // derivative of dist wrt raw_normal
    const float3 grad_raw_normal = normalize_backward(raw_normal, grad_normal);
    // derivative of dist wrt v2v0 and v1v0
    const auto grad_cross = cross_backward(v2v0, v1v0, grad_raw_normal);
    const float3 grad_cross_v2v0 = thrust::get<0>(grad_cross);
    const float3 grad_cross_v1v0 = thrust::get<1>(grad_cross);

    // v0 appears in t directly and in both edge vectors (with a minus sign).
    const float3 grad_v0 =
        grad_dist * 2.0f * t * normal - (grad_cross_v2v0 + grad_cross_v1v0);
    return thrust::make_tuple(grad_p, grad_v0, grad_cross_v1v0, grad_cross_v2v0);
  }

  // Edge case: back-propagate through the closest of the three edges only.
  float3 grad_p = make_float3(0.0f, 0.0f, 0.0f);
  float3 grad_v0 = make_float3(0.0f, 0.0f, 0.0f);
  float3 grad_v1 = make_float3(0.0f, 0.0f, 0.0f);
  float3 grad_v2 = make_float3(0.0f, 0.0f, 0.0f);

  const float e01 = PointLine3DistanceForward(p, v0, v1);
  const float e02 = PointLine3DistanceForward(p, v0, v2);
  const float e12 = PointLine3DistanceForward(p, v1, v2);

  if ((e01 <= e02) && (e01 <= e12)) {
    // e01 is smallest
    thrust::tie(grad_p, grad_v0, grad_v1) =
        PointLine3DistanceBackward(p, v0, v1, grad_dist);
  } else if ((e02 <= e01) && (e02 <= e12)) {
    // e02 is smallest
    thrust::tie(grad_p, grad_v0, grad_v2) =
        PointLine3DistanceBackward(p, v0, v2, grad_dist);
  } else if ((e12 <= e01) && (e12 <= e02)) {
    // e12 is smallest
    thrust::tie(grad_p, grad_v1, grad_v2) =
        PointLine3DistanceBackward(p, v1, v2, grad_dist);
  }

  return thrust::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
}
|
||||
398
pytorch3d/csrc/utils/geometry_utils.h
Normal file
398
pytorch3d/csrc/utils/geometry_utils.h
Normal file
@@ -0,0 +1,398 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <algorithm>
|
||||
#include <type_traits>
|
||||
#include "vec2.h"
|
||||
#include "vec3.h"
|
||||
|
||||
// Small positive constant used to guard against floating point errors and
// division by 0.
constexpr double kEpsilon = 1e-8;
|
||||
|
||||
// Determines whether a point p is on the right side of a 2D line segment
|
||||
// given by the end points v0, v1.
|
||||
//
|
||||
// Args:
|
||||
// p: vec2 Coordinates of a point.
|
||||
// v0, v1: vec2 Coordinates of the end points of the edge.
|
||||
//
|
||||
// Returns:
|
||||
// area: The signed area of the parallelogram given by the vectors
|
||||
// A = p - v0
|
||||
// B = v1 - v0
|
||||
//
|
||||
// v1 ________
|
||||
// /\ /
|
||||
// A / \ /
|
||||
// / \ /
|
||||
// v0 /______\/
|
||||
// B p
|
||||
//
|
||||
// The area can also be interpreted as the cross product A x B.
|
||||
// If the sign of the area is positive, the point p is on the
|
||||
// right side of the edge. Negative area indicates the point is on
|
||||
// the left side of the edge. i.e. for an edge v1 - v0:
|
||||
//
|
||||
// v1
|
||||
// /
|
||||
// /
|
||||
// - / +
|
||||
// /
|
||||
// /
|
||||
// v0
|
||||
//
|
||||
template <typename T>
|
||||
T EdgeFunctionForward(const vec2<T>& p, const vec2<T>& v0, const vec2<T>& v1) {
|
||||
const T edge = (p.x - v0.x) * (v1.y - v0.y) - (p.y - v0.y) * (v1.x - v0.x);
|
||||
return edge;
|
||||
}
|
||||
|
||||
// Backward pass for the edge function returning partial dervivatives for each
|
||||
// of the input points.
|
||||
//
|
||||
// Args:
|
||||
// p: vec2 Coordinates of a point.
|
||||
// v0, v1: vec2 Coordinates of the end points of the edge.
|
||||
// grad_edge: Upstream gradient for output from edge function.
|
||||
//
|
||||
// Returns:
|
||||
// tuple of gradients for each of the input points:
|
||||
// (vec2<T> d_edge_dp, vec2<T> d_edge_dv0, vec2<T> d_edge_dv1)
|
||||
//
|
||||
template <typename T>
|
||||
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> EdgeFunctionBackward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const T grad_edge) {
|
||||
const vec2<T> dedge_dp(v1.y - v0.y, v0.x - v1.x);
|
||||
const vec2<T> dedge_dv0(p.y - v1.y, v1.x - p.x);
|
||||
const vec2<T> dedge_dv1(v0.y - p.y, p.x - v0.x);
|
||||
return std::make_tuple(
|
||||
grad_edge * dedge_dp, grad_edge * dedge_dv0, grad_edge * dedge_dv1);
|
||||
}
|
||||
|
||||
// The forward pass for computing the barycentric coordinates of a point
|
||||
// relative to a triangle.
|
||||
// Ref:
|
||||
// https://www.scratchapixel.com/lessons/3d-basic-rendering/ray-tracing-rendering-a-triangle/barycentric-coordinates
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: Coordinates of the triangle vertices.
|
||||
//
|
||||
// Returns
|
||||
// bary: (w0, w1, w2) barycentric coordinates in the range [0, 1].
|
||||
//
|
||||
template <typename T>
|
||||
vec3<T> BarycentricCoordinatesForward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const vec2<T>& v2) {
|
||||
const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
|
||||
const T w0 = EdgeFunctionForward(p, v1, v2) / area;
|
||||
const T w1 = EdgeFunctionForward(p, v2, v0) / area;
|
||||
const T w2 = EdgeFunctionForward(p, v0, v1) / area;
|
||||
return vec3<T>(w0, w1, w2);
|
||||
}
|
||||
|
||||
// The backward pass for computing the barycentric coordinates of a point
|
||||
// relative to a triangle.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: (x, y) coordinates of the triangle vertices.
|
||||
// grad_bary_upstream: vec3<T> Upstream gradient for each of the
|
||||
// barycentric coordaintes [grad_w0, grad_w1, grad_w2].
|
||||
//
|
||||
// Returns
|
||||
// tuple of gradients for each of the triangle vertices:
|
||||
// (vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
|
||||
//
|
||||
template <typename T>
|
||||
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>> BarycentricCoordsBackward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const vec2<T>& v2,
|
||||
const vec3<T>& grad_bary_upstream) {
|
||||
const T area = EdgeFunctionForward(v2, v0, v1) + kEpsilon;
|
||||
const T area2 = pow(area, 2.0f);
|
||||
const T area_inv = 1.0f / area;
|
||||
const T e0 = EdgeFunctionForward(p, v1, v2);
|
||||
const T e1 = EdgeFunctionForward(p, v2, v0);
|
||||
const T e2 = EdgeFunctionForward(p, v0, v1);
|
||||
|
||||
const T grad_w0 = grad_bary_upstream.x;
|
||||
const T grad_w1 = grad_bary_upstream.y;
|
||||
const T grad_w2 = grad_bary_upstream.z;
|
||||
|
||||
// Calculate component of the gradient from each of w0, w1 and w2.
|
||||
// e.g. for w0:
|
||||
// dloss/dw0_v = dl/dw0 * dw0/dw0_top * dw0_top/dv
|
||||
// + dl/dw0 * dw0/dw0_bot * dw0_bot/dv
|
||||
const T dw0_darea = -e0 / (area2);
|
||||
const T dw0_e0 = area_inv;
|
||||
const T dloss_d_w0area = grad_w0 * dw0_darea;
|
||||
const T dloss_e0 = grad_w0 * dw0_e0;
|
||||
auto de0_dv = EdgeFunctionBackward(p, v1, v2, dloss_e0);
|
||||
auto dw0area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w0area);
|
||||
const vec2<T> dw0_p = std::get<0>(de0_dv);
|
||||
const vec2<T> dw0_dv0 = std::get<1>(dw0area_dv);
|
||||
const vec2<T> dw0_dv1 = std::get<1>(de0_dv) + std::get<2>(dw0area_dv);
|
||||
const vec2<T> dw0_dv2 = std::get<2>(de0_dv) + std::get<0>(dw0area_dv);
|
||||
|
||||
const T dw1_darea = -e1 / (area2);
|
||||
const T dw1_e1 = area_inv;
|
||||
const T dloss_d_w1area = grad_w1 * dw1_darea;
|
||||
const T dloss_e1 = grad_w1 * dw1_e1;
|
||||
auto de1_dv = EdgeFunctionBackward(p, v2, v0, dloss_e1);
|
||||
auto dw1area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w1area);
|
||||
const vec2<T> dw1_p = std::get<0>(de1_dv);
|
||||
const vec2<T> dw1_dv0 = std::get<2>(de1_dv) + std::get<1>(dw1area_dv);
|
||||
const vec2<T> dw1_dv1 = std::get<2>(dw1area_dv);
|
||||
const vec2<T> dw1_dv2 = std::get<1>(de1_dv) + std::get<0>(dw1area_dv);
|
||||
|
||||
const T dw2_darea = -e2 / (area2);
|
||||
const T dw2_e2 = area_inv;
|
||||
const T dloss_d_w2area = grad_w2 * dw2_darea;
|
||||
const T dloss_e2 = grad_w2 * dw2_e2;
|
||||
auto de2_dv = EdgeFunctionBackward(p, v0, v1, dloss_e2);
|
||||
auto dw2area_dv = EdgeFunctionBackward(v2, v0, v1, dloss_d_w2area);
|
||||
const vec2<T> dw2_p = std::get<0>(de2_dv);
|
||||
const vec2<T> dw2_dv0 = std::get<1>(de2_dv) + std::get<1>(dw2area_dv);
|
||||
const vec2<T> dw2_dv1 = std::get<2>(de2_dv) + std::get<2>(dw2area_dv);
|
||||
const vec2<T> dw2_dv2 = std::get<0>(dw2area_dv);
|
||||
|
||||
const vec2<T> dbary_p = dw0_p + dw1_p + dw2_p;
|
||||
const vec2<T> dbary_dv0 = dw0_dv0 + dw1_dv0 + dw2_dv0;
|
||||
const vec2<T> dbary_dv1 = dw0_dv1 + dw1_dv1 + dw2_dv1;
|
||||
const vec2<T> dbary_dv2 = dw0_dv2 + dw1_dv2 + dw2_dv2;
|
||||
|
||||
return std::make_tuple(dbary_p, dbary_dv0, dbary_dv1, dbary_dv2);
|
||||
}
|
||||
|
||||
// Forward pass for applying perspective correction to barycentric coordinates.
|
||||
//
|
||||
// Args:
|
||||
// bary: Screen-space barycentric coordinates for a point
|
||||
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
|
||||
//
|
||||
// Returns
|
||||
// World-space barycentric coordinates
|
||||
//
|
||||
template <typename T>
|
||||
inline vec3<T> BarycentricPerspectiveCorrectionForward(
|
||||
const vec3<T>& bary,
|
||||
const T z0,
|
||||
const T z1,
|
||||
const T z2) {
|
||||
const T w0_top = bary.x * z1 * z2;
|
||||
const T w1_top = bary.y * z0 * z2;
|
||||
const T w2_top = bary.z * z0 * z1;
|
||||
const T denom = w0_top + w1_top + w2_top;
|
||||
const T w0 = w0_top / denom;
|
||||
const T w1 = w1_top / denom;
|
||||
const T w2 = w2_top / denom;
|
||||
return vec3<T>(w0, w1, w2);
|
||||
}
|
||||
|
||||
// Backward pass for applying perspective correction to barycentric coordinates.
|
||||
//
|
||||
// Args:
|
||||
// bary: Screen-space barycentric coordinates for a point
|
||||
// z0, z1, z2: Camera-space z-coordinates of the triangle vertices
|
||||
// grad_out: Upstream gradient of the loss with respect to the corrected
|
||||
// barycentric coordinates.
|
||||
//
|
||||
// Returns a tuple of:
|
||||
// grad_bary: Downstream gradient of the loss with respect to the the
|
||||
// uncorrected barycentric coordinates.
|
||||
// grad_z0, grad_z1, grad_z2: Downstream gradient of the loss with respect
|
||||
// to the z-coordinates of the triangle verts
|
||||
template <typename T>
|
||||
inline std::tuple<vec3<T>, T, T, T> BarycentricPerspectiveCorrectionBackward(
|
||||
const vec3<T>& bary,
|
||||
const T z0,
|
||||
const T z1,
|
||||
const T z2,
|
||||
const vec3<T>& grad_out) {
|
||||
// Recompute forward pass
|
||||
const T w0_top = bary.x * z1 * z2;
|
||||
const T w1_top = bary.y * z0 * z2;
|
||||
const T w2_top = bary.z * z0 * z1;
|
||||
const T denom = w0_top + w1_top + w2_top;
|
||||
|
||||
// Now do backward pass
|
||||
const T grad_denom_top =
|
||||
-w0_top * grad_out.x - w1_top * grad_out.y - w2_top * grad_out.z;
|
||||
const T grad_denom = grad_denom_top / (denom * denom);
|
||||
const T grad_w0_top = grad_denom + grad_out.x / denom;
|
||||
const T grad_w1_top = grad_denom + grad_out.y / denom;
|
||||
const T grad_w2_top = grad_denom + grad_out.z / denom;
|
||||
const T grad_bary_x = grad_w0_top * z1 * z2;
|
||||
const T grad_bary_y = grad_w1_top * z0 * z2;
|
||||
const T grad_bary_z = grad_w2_top * z0 * z1;
|
||||
const vec3<T> grad_bary(grad_bary_x, grad_bary_y, grad_bary_z);
|
||||
const T grad_z0 = grad_w1_top * bary.y * z2 + grad_w2_top * bary.z * z1;
|
||||
const T grad_z1 = grad_w0_top * bary.x * z2 + grad_w2_top * bary.z * z0;
|
||||
const T grad_z2 = grad_w0_top * bary.x * z1 + grad_w1_top * bary.y * z0;
|
||||
return std::make_tuple(grad_bary, grad_z0, grad_z1, grad_z2);
|
||||
}
|
||||
|
||||
// Calculate minimum squared distance between a line segment (v1 - v0) and a
|
||||
// point p.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1: Coordinates of the end points of the line segment.
|
||||
//
|
||||
// Returns:
|
||||
// squared distance of the point to the line.
|
||||
//
|
||||
// Consider the line extending the segment - this can be parameterized as:
|
||||
// v0 + t (v1 - v0).
|
||||
//
|
||||
// First find the projection of point p onto the line. It falls where:
|
||||
// t = [(p - v0) . (v1 - v0)] / |v1 - v0|^2
|
||||
// where . is the dot product.
|
||||
//
|
||||
// The parameter t is clamped from [0, 1] to handle points outside the
|
||||
// segment (v1 - v0).
|
||||
//
|
||||
// Once the projection of the point on the segment is known, the distance from
|
||||
// p to the projection gives the minimum distance to the segment.
|
||||
//
|
||||
template <typename T>
|
||||
T PointLineDistanceForward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1) {
|
||||
const vec2<T> v1v0 = v1 - v0;
|
||||
const T l2 = dot(v1v0, v1v0);
|
||||
if (l2 <= kEpsilon) {
|
||||
return dot(p - v1, p - v1);
|
||||
}
|
||||
|
||||
const T t = dot(v1v0, p - v0) / l2;
|
||||
const T tt = std::min(std::max(t, 0.00f), 1.00f);
|
||||
const vec2<T> p_proj = v0 + tt * v1v0;
|
||||
return dot(p - p_proj, p - p_proj);
|
||||
}
|
||||
|
||||
// Backward pass for point to line distance in 2D.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1: Coordinates of the end points of the line segment.
|
||||
// grad_dist: Upstream gradient for the distance.
|
||||
//
|
||||
// Returns:
|
||||
// tuple of gradients for each of the input points:
|
||||
// (vec2<T> grad_p, vec2<T> grad_v0, vec2<T> grad_v1)
|
||||
//
|
||||
template <typename T>
|
||||
inline std::tuple<vec2<T>, vec2<T>, vec2<T>> PointLineDistanceBackward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const T& grad_dist) {
|
||||
// Redo some of the forward pass calculations.
|
||||
const vec2<T> v1v0 = v1 - v0;
|
||||
const vec2<T> pv0 = p - v0;
|
||||
const T t_bot = dot(v1v0, v1v0);
|
||||
const T t_top = dot(v1v0, pv0);
|
||||
const T t = t_top / t_bot;
|
||||
const T tt = std::min(std::max(t, 0.00f), 1.00f);
|
||||
const vec2<T> p_proj = (1.0f - tt) * v0 + tt * v1;
|
||||
|
||||
const vec2<T> grad_v0 = grad_dist * (1.0f - tt) * 2.0f * (p_proj - p);
|
||||
const vec2<T> grad_v1 = grad_dist * tt * 2.0f * (p_proj - p);
|
||||
const vec2<T> grad_p = -1.0f * grad_dist * 2.0f * (p_proj - p);
|
||||
|
||||
return std::make_tuple(grad_p, grad_v0, grad_v1);
|
||||
}
|
||||
|
||||
// The forward pass for calculating the shortest distance between a point
|
||||
// and a triangle.
|
||||
// Ref: https://www.randygaul.net/2014/07/23/distance-point-to-line-segment/
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: Coordinates of the three triangle vertices.
|
||||
//
|
||||
// Returns:
|
||||
// shortest squared distance from a point to a triangle.
|
||||
//
|
||||
//
|
||||
template <typename T>
|
||||
T PointTriangleDistanceForward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const vec2<T>& v2) {
|
||||
// Compute distance of point to 3 edges of the triangle and return the
|
||||
// minimum value.
|
||||
const T e01_dist = PointLineDistanceForward(p, v0, v1);
|
||||
const T e02_dist = PointLineDistanceForward(p, v0, v2);
|
||||
const T e12_dist = PointLineDistanceForward(p, v1, v2);
|
||||
const T edge_dist = std::min(std::min(e01_dist, e02_dist), e12_dist);
|
||||
|
||||
return edge_dist;
|
||||
}
|
||||
|
||||
// Backward pass for point triangle distance.
|
||||
//
|
||||
// Args:
|
||||
// p: Coordinates of a point.
|
||||
// v0, v1, v2: Coordinates of the three triangle vertices.
|
||||
// grad_dist: Upstream gradient for the distance.
|
||||
//
|
||||
// Returns:
|
||||
// tuple of gradients for each of the triangle vertices:
|
||||
// (vec2<T> grad_v0, vec2<T> grad_v1, vec2<T> grad_v2)
|
||||
//
|
||||
template <typename T>
|
||||
inline std::tuple<vec2<T>, vec2<T>, vec2<T>, vec2<T>>
|
||||
PointTriangleDistanceBackward(
|
||||
const vec2<T>& p,
|
||||
const vec2<T>& v0,
|
||||
const vec2<T>& v1,
|
||||
const vec2<T>& v2,
|
||||
const T& grad_dist) {
|
||||
// Compute distance to all 3 edges of the triangle.
|
||||
const T e01_dist = PointLineDistanceForward(p, v0, v1);
|
||||
const T e02_dist = PointLineDistanceForward(p, v0, v2);
|
||||
const T e12_dist = PointLineDistanceForward(p, v1, v2);
|
||||
|
||||
// Initialize output tensors.
|
||||
vec2<T> grad_v0(0.0f, 0.0f);
|
||||
vec2<T> grad_v1(0.0f, 0.0f);
|
||||
vec2<T> grad_v2(0.0f, 0.0f);
|
||||
vec2<T> grad_p(0.0f, 0.0f);
|
||||
|
||||
// Find which edge is the closest and return PointLineDistanceBackward for
|
||||
// that edge.
|
||||
if (e01_dist <= e02_dist && e01_dist <= e12_dist) {
|
||||
// Closest edge is v1 - v0.
|
||||
auto grad_e01 = PointLineDistanceBackward(p, v0, v1, grad_dist);
|
||||
grad_p = std::get<0>(grad_e01);
|
||||
grad_v0 = std::get<1>(grad_e01);
|
||||
grad_v1 = std::get<2>(grad_e01);
|
||||
} else if (e02_dist <= e01_dist && e02_dist <= e12_dist) {
|
||||
// Closest edge is v2 - v0.
|
||||
auto grad_e02 = PointLineDistanceBackward(p, v0, v2, grad_dist);
|
||||
grad_p = std::get<0>(grad_e02);
|
||||
grad_v0 = std::get<1>(grad_e02);
|
||||
grad_v2 = std::get<2>(grad_e02);
|
||||
} else if (e12_dist <= e01_dist && e12_dist <= e02_dist) {
|
||||
// Closest edge is v2 - v1.
|
||||
auto grad_e12 = PointLineDistanceBackward(p, v1, v2, grad_dist);
|
||||
grad_p = std::get<0>(grad_e12);
|
||||
grad_v1 = std::get<1>(grad_e12);
|
||||
grad_v2 = std::get<2>(grad_e12);
|
||||
}
|
||||
|
||||
return std::make_tuple(grad_p, grad_v0, grad_v1, grad_v2);
|
||||
}
|
||||
218
pytorch3d/csrc/utils/index_utils.cuh
Normal file
218
pytorch3d/csrc/utils/index_utils.cuh
Normal file
@@ -0,0 +1,218 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
// This converts dynamic array lookups into static array lookups, for small
|
||||
// arrays up to size 32.
|
||||
//
|
||||
// Suppose we have a small thread-local array:
|
||||
//
|
||||
// float vals[10];
|
||||
//
|
||||
// Ideally we should only index this array using static indices:
|
||||
//
|
||||
// for (int i = 0; i < 10; ++i) vals[i] = i * i;
|
||||
//
|
||||
// If we do so, then the CUDA compiler may be able to place the array into
|
||||
// registers, which can have a big performance improvement. However if we
|
||||
// access the array dynamically, the compiler may force the array into
|
||||
// local memory, which has the same latency as global memory.
|
||||
//
|
||||
// These functions convert dynamic array access into static array access
|
||||
// using a brute-force lookup table. It can be used like this:
|
||||
//
|
||||
// float vals[10];
|
||||
// int idx = 3;
|
||||
// float val = 3.14f;
|
||||
// RegisterIndexUtils<float, 10>::set(vals, idx, val);
|
||||
// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
|
||||
//
|
||||
// The implementation is based on fbcuda/RegisterUtils.cuh:
|
||||
// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
|
||||
// To avoid depending on the entire library, we just reimplement these two
|
||||
// functions. The fbcuda implementation is a bit more sophisticated, and uses
|
||||
// the preprocessor to generate switch statements that go up to N for each
|
||||
// value of N. We are lazy and just have a giant explicit switch statement.
|
||||
//
|
||||
// We might be able to use a template metaprogramming approach similar to
|
||||
// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
|
||||
// for dispatching to the correct CUDA kernel on the host, while this is
|
||||
// intended to run on the device. I was concerned that a metaprogramming
|
||||
// approach for this might lead to extra function calls at runtime if the
|
||||
// compiler fails to optimize them away, which could be very slow on device.
|
||||
// However I didn't actually benchmark or test this.
|
||||
// Accessors that read or write one element of a small array through a switch
// over compile-time-constant indices instead of a dynamic subscript.
//
// Template arguments:
//   T: element type. get() returns a value-initialized T for out-of-range
//      indices, so T must be default-constructible.
//   N: logical size of the array. Must be at most 32, because the switch
//      statements below only enumerate indices 0..31.
template <typename T, int N>
struct RegisterIndexUtils {
  // Without this check an instantiation with N > 32 compiles, but get()
  // silently returns T() and set() silently drops the write for every index
  // in [32, N). Fail at compile time instead.
  static_assert(
      N <= 32,
      "RegisterIndexUtils only supports arrays of at most 32 elements");

  // Returns arr[idx], or T() if idx is outside [0, N).
  __device__ __forceinline__ static T get(const T arr[N], int idx) {
    if (idx < 0 || idx >= N)
      return T();
    // The range check above guarantees that only cases < N are ever taken.
    switch (idx) {
      case 0:
        return arr[0];
      case 1:
        return arr[1];
      case 2:
        return arr[2];
      case 3:
        return arr[3];
      case 4:
        return arr[4];
      case 5:
        return arr[5];
      case 6:
        return arr[6];
      case 7:
        return arr[7];
      case 8:
        return arr[8];
      case 9:
        return arr[9];
      case 10:
        return arr[10];
      case 11:
        return arr[11];
      case 12:
        return arr[12];
      case 13:
        return arr[13];
      case 14:
        return arr[14];
      case 15:
        return arr[15];
      case 16:
        return arr[16];
      case 17:
        return arr[17];
      case 18:
        return arr[18];
      case 19:
        return arr[19];
      case 20:
        return arr[20];
      case 21:
        return arr[21];
      case 22:
        return arr[22];
      case 23:
        return arr[23];
      case 24:
        return arr[24];
      case 25:
        return arr[25];
      case 26:
        return arr[26];
      case 27:
        return arr[27];
      case 28:
        return arr[28];
      case 29:
        return arr[29];
      case 30:
        return arr[30];
      case 31:
        return arr[31];
    };
    return T();
  }

  // Stores val into arr[idx]. Does nothing if idx is outside [0, N).
  __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
    if (idx < 0 || idx >= N)
      return;
    switch (idx) {
      case 0:
        arr[0] = val;
        break;
      case 1:
        arr[1] = val;
        break;
      case 2:
        arr[2] = val;
        break;
      case 3:
        arr[3] = val;
        break;
      case 4:
        arr[4] = val;
        break;
      case 5:
        arr[5] = val;
        break;
      case 6:
        arr[6] = val;
        break;
      case 7:
        arr[7] = val;
        break;
      case 8:
        arr[8] = val;
        break;
      case 9:
        arr[9] = val;
        break;
      case 10:
        arr[10] = val;
        break;
      case 11:
        arr[11] = val;
        break;
      case 12:
        arr[12] = val;
        break;
      case 13:
        arr[13] = val;
        break;
      case 14:
        arr[14] = val;
        break;
      case 15:
        arr[15] = val;
        break;
      case 16:
        arr[16] = val;
        break;
      case 17:
        arr[17] = val;
        break;
      case 18:
        arr[18] = val;
        break;
      case 19:
        arr[19] = val;
        break;
      case 20:
        arr[20] = val;
        break;
      case 21:
        arr[21] = val;
        break;
      case 22:
        arr[22] = val;
        break;
      case 23:
        arr[23] = val;
        break;
      case 24:
        arr[24] = val;
        break;
      case 25:
        arr[25] = val;
        break;
      case 26:
        arr[26] = val;
        break;
      case 27:
        arr[27] = val;
        break;
      case 28:
        arr[28] = val;
        break;
      case 29:
        arr[29] = val;
        break;
      case 30:
        arr[30] = val;
        break;
      case 31:
        arr[31] = val;
        break;
    }
  }
};
|
||||
159
pytorch3d/csrc/utils/mink.cuh
Normal file
159
pytorch3d/csrc/utils/mink.cuh
Normal file
@@ -0,0 +1,159 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#define MINK_H
|
||||
|
||||
#include "index_utils.cuh"
|
||||
|
||||
// A data structure to keep track of the smallest K keys seen so far as well
|
||||
// as their associated values, intended to be used in device code.
|
||||
// This data structure doesn't allocate any memory; keys and values are stored
|
||||
// in arrays passed to the constructor.
|
||||
//
|
||||
// The implementation is generic; it can be used for any key type that supports
|
||||
// the < operator, and can be used with any value type.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// float keys[K];
|
||||
// int values[K];
|
||||
// MinK<float, int> mink(keys, values, K);
|
||||
// for (...) {
|
||||
// // Produce some key and value from somewhere
|
||||
// mink.add(key, value);
|
||||
// }
|
||||
// mink.sort();
|
||||
//
|
||||
// Now keys and values store the smallest K keys seen so far and the values
|
||||
// associated to these keys:
|
||||
//
|
||||
// for (int k = 0; k < K; ++k) {
|
||||
// float key_k = keys[k];
|
||||
// int value_k = values[k];
|
||||
// }
|
||||
// Tracks the K smallest keys seen so far, together with their values, in
// caller-provided storage. The class never allocates; it only reads and
// writes the two arrays handed to the constructor.
template <typename key_t, typename value_t>
class MinK {
 public:
  // Constructor.
  //
  // Arguments:
  //   keys: Array in which to store keys (indexed from 0 to K - 1)
  //   vals: Array in which to store values (indexed from 0 to K - 1)
  //   K: How many values to keep track of
  //
  // max_key / max_idx are given defined initial values so that no code path
  // can read them uninitialized (relevant when K <= 0).
  __device__ MinK(key_t* keys, value_t* vals, int K)
      : keys(keys), vals(vals), K(K), _size(0), max_key(), max_idx(0) {}

  // Try to add a new key and associated value to the data structure. If the
  // key is one of the smallest K seen so far then it will be kept; otherwise
  // it will not be kept.
  //
  // This takes O(1) operations if the new key is not kept, or if the structure
  // currently contains fewer than K elements. Otherwise this takes O(K) time.
  //
  // Arguments:
  //   key: The key to add
  //   val: The value associated to the key
  __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
    if (_size < K) {
      // Still filling up: append, and remember where the largest key lives.
      keys[_size] = key;
      vals[_size] = val;
      if (_size == 0 || key > max_key) {
        max_key = key;
        max_idx = _size;
      }
      _size++;
    } else if (_size > 0 && key < max_key) {
      // Full (the _size > 0 test excludes the degenerate K <= 0 case, where
      // there is no slot to overwrite): evict the current maximum, then
      // rescan all K slots to find the new maximum.
      keys[max_idx] = key;
      vals[max_idx] = val;
      max_key = key;
      for (int k = 0; k < K; ++k) {
        key_t cur_key = keys[k];
        if (cur_key > max_key) {
          max_key = cur_key;
          max_idx = k;
        }
      }
    }
  }

  // Get the number of items currently stored in the structure.
  // This takes O(1) time.
  __device__ __forceinline__ int size() const {
    return _size;
  }

  // Sort the items stored in the structure by ascending key using bubble
  // sort; values are moved together with their keys.
  // This takes O(K^2) time.
  __device__ __forceinline__ void sort() {
    for (int i = 0; i < _size - 1; ++i) {
      for (int j = 0; j < _size - i - 1; ++j) {
        if (keys[j + 1] < keys[j]) {
          key_t key = keys[j];
          value_t val = vals[j];
          keys[j] = keys[j + 1];
          vals[j] = vals[j + 1];
          keys[j + 1] = key;
          vals[j + 1] = val;
        }
      }
    }
  }

 private:
  key_t* keys; // caller-owned key storage
  value_t* vals; // caller-owned value storage
  int K; // capacity
  int _size; // number of slots currently filled
  key_t max_key; // largest key currently stored (valid once _size > 0)
  int max_idx; // slot holding max_key
};
|
||||
|
||||
// This is a version of MinK that only touches the arrays using static indexing
|
||||
// via RegisterIndexUtils. If the keys and values are stored in thread-local
|
||||
// arrays, then this may allow the compiler to place them in registers for
|
||||
// fast access.
|
||||
//
|
||||
// This has the same API as MinK, but doesn't support sorting.
|
||||
// We found that sorting via RegisterIndexUtils gave very poor performance,
|
||||
// and suspect it may have prevented the compiler from placing the arrays
|
||||
// into registers.
|
||||
template <typename key_t, typename value_t, int K>
|
||||
class RegisterMinK {
|
||||
public:
|
||||
__device__ RegisterMinK(key_t* keys, value_t* vals)
|
||||
: keys(keys), vals(vals), _size(0) {}
|
||||
|
||||
__device__ __forceinline__ void add(const key_t& key, const value_t& val) {
|
||||
if (_size < K) {
|
||||
RegisterIndexUtils<key_t, K>::set(keys, _size, key);
|
||||
RegisterIndexUtils<value_t, K>::set(vals, _size, val);
|
||||
if (_size == 0 || key > max_key) {
|
||||
max_key = key;
|
||||
max_idx = _size;
|
||||
}
|
||||
_size++;
|
||||
} else if (key < max_key) {
|
||||
RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
|
||||
RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
|
||||
max_key = key;
|
||||
for (int k = 0; k < K; ++k) {
|
||||
key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
|
||||
if (cur_key > max_key) {
|
||||
max_key = cur_key;
|
||||
max_idx = k;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int size() {
|
||||
return _size;
|
||||
}
|
||||
|
||||
private:
|
||||
key_t* keys;
|
||||
value_t* vals;
|
||||
int _size;
|
||||
key_t max_key;
|
||||
int max_idx;
|
||||
};
|
||||
11
pytorch3d/csrc/utils/pytorch3d_cutils.h
Normal file
11
pytorch3d/csrc/utils/pytorch3d_cutils.h
Normal file
@@ -0,0 +1,11 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <torch/extension.h>
|
||||
|
||||
// Argument-validation helpers for tensor inputs. Each macro stringifies its
// argument so that the failure message names the offending expression.

// Fails via AT_ASSERTM unless x.is_cuda() is true.
// (The leading space in the message keeps the stringified name and the text
// from running together, e.g. "points must be ..." rather than "pointsmust".)
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor.")

// Fails via AT_ASSERTM unless x.is_contiguous() is true.
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous.")

// Runs both checks above. Wrapped in do { } while (0) so that the macro
// expands to a single statement and both checks stay under the same
// condition when used as the body of an unbraced if/else.
#define CHECK_CONTIGUOUS_CUDA(x) \
  do {                           \
    CHECK_CUDA(x);               \
    CHECK_CONTIGUOUS(x);         \
  } while (0)
|
||||
59
pytorch3d/csrc/utils/vec2.h
Normal file
59
pytorch3d/csrc/utils/vec2.h
Normal file
@@ -0,0 +1,59 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include <cmath>
#include <type_traits>
|
||||
|
||||
// A fixed-sized vector with basic arithmetic operators useful for
|
||||
// representing 2D coordinates.
|
||||
// TODO: switch to Eigen if more functionality is needed.
|
||||
|
||||
// Two-component vector of float or double. The defaulted second template
// argument makes vec2<T> ill-formed for any other T.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec2 {
  // Components, in declaration order x then y.
  T x, y;
  // Element type, exposed for generic code.
  using scalar_t = T;
  // Builds the vector (x, y).
  vec2(T x, T y) : x(x), y(y) {}
};
|
||||
|
||||
template <typename T>
|
||||
inline vec2<T> operator+(const vec2<T>& a, const vec2<T>& b) {
|
||||
return vec2<T>(a.x + b.x, a.y + b.y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec2<T> operator-(const vec2<T>& a, const vec2<T>& b) {
|
||||
return vec2<T>(a.x - b.x, a.y - b.y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec2<T> operator*(const T a, const vec2<T>& b) {
|
||||
return vec2<T>(a * b.x, a * b.y);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec2<T> operator/(const vec2<T>& a, const T b) {
|
||||
if (b == 0.0) {
|
||||
AT_ERROR(
|
||||
"denominator in vec2 division is 0"); // prevent divide by 0 errors.
|
||||
}
|
||||
return vec2<T>(a.x / b, a.y / b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T dot(const vec2<T>& a, const vec2<T>& b) {
|
||||
return a.x * b.x + a.y * b.y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T norm(const vec2<T>& a, const vec2<T>& b) {
|
||||
const vec2<T> ba = b - a;
|
||||
return sqrt(dot(ba, ba));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const vec2<T>& v) {
|
||||
os << "vec2(" << v.x << ", " << v.y << ")";
|
||||
return os;
|
||||
}
|
||||
63
pytorch3d/csrc/utils/vec3.h
Normal file
63
pytorch3d/csrc/utils/vec3.h
Normal file
@@ -0,0 +1,63 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// A fixed-sized vector with basic arithmetic operators useful for
|
||||
// representing 3D coordinates.
|
||||
// TODO: switch to Eigen if more functionality is needed.
|
||||
|
||||
// Three-component vector of float or double. The defaulted second template
// argument makes vec3<T> ill-formed for any other T.
// NOTE(review): std::enable_if_t / std::is_same come from <type_traits>, which
// this header does not appear to include itself - confirm that includers
// provide it.
template <
    typename T,
    typename = std::enable_if_t<
        std::is_same<T, double>::value || std::is_same<T, float>::value>>
struct vec3 {
  // Components, in declaration order x, y, z.
  T x, y, z;
  // Element type, exposed for generic code.
  using scalar_t = T;
  // Builds the vector (x, y, z).
  vec3(T x, T y, T z) : x(x), y(y), z(z) {}
};
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> operator+(const vec3<T>& a, const vec3<T>& b) {
|
||||
return vec3<T>(a.x + b.x, a.y + b.y, a.z + b.z);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> operator-(const vec3<T>& a, const vec3<T>& b) {
|
||||
return vec3<T>(a.x - b.x, a.y - b.y, a.z - b.z);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> operator/(const vec3<T>& a, const T b) {
|
||||
if (b == 0.0) {
|
||||
AT_ERROR(
|
||||
"denominator in vec3 division is 0"); // prevent divide by 0 errors.
|
||||
}
|
||||
return vec3<T>(a.x / b, a.y / b, a.z / b);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> operator*(const T a, const vec3<T>& b) {
|
||||
return vec3<T>(a * b.x, a * b.y, a * b.z);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> operator*(const vec3<T>& a, const vec3<T>& b) {
|
||||
return vec3<T>(a.x * b.x, a.y * b.y, a.z * b.z);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T dot(const vec3<T>& a, const vec3<T>& b) {
|
||||
return a.x * b.x + a.y * b.y + a.z * b.z;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline vec3<T> cross(const vec3<T>& a, const vec3<T>& b) {
|
||||
return vec3<T>(
|
||||
a.y * b.z - a.z * b.y, a.z * b.x - a.x * b.z, a.x * b.y - a.y * b.x);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const vec3<T>& v) {
|
||||
os << "vec3(" << v.x << ", " << v.y << ", " << v.z << ")";
|
||||
return os;
|
||||
}
|
||||
44
pytorch3d/csrc/utils/warp_reduce.cuh
Normal file
44
pytorch3d/csrc/utils/warp_reduce.cuh
Normal file
@@ -0,0 +1,44 @@
|
||||
// Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <cstdio>
|
||||
|
||||
// helper WarpReduce used in .cu files
|
||||
|
||||
// Unrolled final stage of a min-reduction over (distance, index) pairs.
//
// For each stride s in 32, 16, 8, 4, 2, 1 the calling thread compares
// min_dists[tid] with min_dists[tid + s] and, if the latter is smaller,
// copies both the distance and its index into slot tid.
//
// Arguments:
//   min_dists: distances being reduced; elements tid .. tid + 32 are accessed
//   min_idxs: indices paired with min_dists; updated alongside it
//   tid: slot owned by the calling thread
//
// NOTE(review): presumably both arrays live in shared memory and this is
// called by the threads of a single full warp (tid < 32) - confirm against
// the callers. The __syncwarp() calls below use the default full-warp mask
// and rely on that.
//
// `volatile` alone does not order accesses between threads on GPUs with
// independent thread scheduling, so an explicit __syncwarp() separates
// consecutive reduction steps.
template <typename scalar_t>
__device__ void WarpReduce(
    volatile scalar_t* min_dists,
    volatile int64_t* min_idxs,
    const size_t tid) {
  // s = 32
  if (min_dists[tid] > min_dists[tid + 32]) {
    min_idxs[tid] = min_idxs[tid + 32];
    min_dists[tid] = min_dists[tid + 32];
  }
  __syncwarp();
  // s = 16
  if (min_dists[tid] > min_dists[tid + 16]) {
    min_idxs[tid] = min_idxs[tid + 16];
    min_dists[tid] = min_dists[tid + 16];
  }
  __syncwarp();
  // s = 8
  if (min_dists[tid] > min_dists[tid + 8]) {
    min_idxs[tid] = min_idxs[tid + 8];
    min_dists[tid] = min_dists[tid + 8];
  }
  __syncwarp();
  // s = 4
  if (min_dists[tid] > min_dists[tid + 4]) {
    min_idxs[tid] = min_idxs[tid + 4];
    min_dists[tid] = min_dists[tid + 4];
  }
  __syncwarp();
  // s = 2
  if (min_dists[tid] > min_dists[tid + 2]) {
    min_idxs[tid] = min_idxs[tid + 2];
    min_dists[tid] = min_dists[tid + 2];
  }
  __syncwarp();
  // s = 1
  if (min_dists[tid] > min_dists[tid + 1]) {
    min_idxs[tid] = min_idxs[tid + 1];
    min_dists[tid] = min_dists[tid + 1];
  }
  __syncwarp();
}
|
||||
Reference in New Issue
Block a user