1 Commits

66 changed files with 245 additions and 596 deletions

View File

@@ -10,7 +10,7 @@
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
DIR=$(dirname "${DIR}")
if [[ -f "${DIR}/BUCK" ]]
if [[ -f "${DIR}/TARGETS" ]]
then
pyfmt "${DIR}"
else

View File

@@ -6,4 +6,4 @@
# pyre-unsafe
__version__ = "0.7.9"
__version__ = "0.7.8"

View File

@@ -32,9 +32,7 @@ __global__ void BallQueryKernel(
at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
const int64_t K,
const float radius,
const float radius2,
const bool skip_points_outside_cube) {
const float radius2) {
const int64_t N = p1.size(0);
const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -53,19 +51,7 @@ __global__ void BallQueryKernel(
// Iterate over points in p2 until desired count is reached or
// all points have been considered
for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
if (skip_points_outside_cube) {
bool is_within_radius = true;
// Filter when any one coordinate is already outside the radius
for (int d = 0; is_within_radius && d < D; ++d) {
scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
is_within_radius = (abs_diff <= radius);
}
if (!is_within_radius) {
continue;
}
}
// Else, calculate the distance between the points and compare
// Calculate the distance between the points
scalar_t dist2 = 0.0;
for (int d = 0; d < D; ++d) {
scalar_t diff = p1[n][i][d] - p2[n][j][d];
@@ -91,8 +77,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& lengths1, // (N,)
const at::Tensor& lengths2, // (N,)
int K,
float radius,
bool skip_points_outside_cube) {
float radius) {
// Check inputs are on the same device
at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
@@ -135,9 +120,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
K_64,
radius,
radius2,
skip_points_outside_cube);
radius2);
}));
AT_CUDA_CHECK(cudaGetLastError());

View File

@@ -25,9 +25,6 @@
// within the radius
// radius: the radius around each point within which the neighbors need to be
// located
// skip_points_outside_cube: If true, reduce multiplications of float values
// by not explicitly calculating distances to points that fall outside the
// D-cube with side length (2*radius) centered at each point in p1.
//
// Returns:
// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -49,8 +46,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius,
const bool skip_points_outside_cube);
const float radius);
// CUDA implementation
std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
@@ -59,8 +55,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
const int K,
const float radius,
const bool skip_points_outside_cube);
const float radius);
// Implementation which is exposed
// Note: the backward pass reuses the KNearestNeighborBackward kernel
@@ -70,8 +65,7 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius,
bool skip_points_outside_cube) {
float radius) {
if (p1.is_cuda() || p2.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(p1);
@@ -82,20 +76,16 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius,
skip_points_outside_cube);
radius);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(p1);
CHECK_CPU(p2);
return BallQueryCpu(
p1.contiguous(),
p2.contiguous(),
lengths1.contiguous(),
lengths2.contiguous(),
K,
radius,
skip_points_outside_cube);
radius);
}

View File

@@ -6,7 +6,6 @@
* LICENSE file in the root directory of this source tree.
*/
#include <math.h>
#include <torch/extension.h>
#include <tuple>
@@ -16,8 +15,7 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const at::Tensor& lengths1,
const at::Tensor& lengths2,
int K,
float radius,
bool skip_points_outside_cube) {
float radius) {
const int N = p1.size(0);
const int P1 = p1.size(1);
const int D = p1.size(2);
@@ -39,16 +37,6 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
const int64_t length2 = lengths2_a[n];
for (int64_t i = 0; i < length1; ++i) {
for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
if (skip_points_outside_cube) {
bool is_within_radius = true;
for (int d = 0; is_within_radius && d < D; ++d) {
float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
is_within_radius = (abs_diff <= radius);
}
if (!is_within_radius) {
continue;
}
}
float dist2 = 0;
for (int d = 0; d < D; ++d) {
float diff = p1_a[n][i][d] - p2_a[n][j][d];

View File

@@ -98,11 +98,6 @@ at::Tensor SigmoidAlphaBlendBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(distances);
CHECK_CPU(pix_to_face);
CHECK_CPU(alphas);
CHECK_CPU(grad_alphas);
return SigmoidAlphaBlendBackwardCpu(
grad_alphas, alphas, distances, pix_to_face, sigma);
}

View File

@@ -33,11 +33,11 @@ __global__ void alphaCompositeCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * H * W;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each feature in each pixel
for (int pid = tid; pid < num_pixels; pid += num_threads) {
@@ -83,11 +83,11 @@ __global__ void alphaCompositeCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * H * W;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size

View File

@@ -74,9 +74,6 @@ torch::Tensor alphaCompositeForward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return alphaCompositeCpuForward(features, alphas, points_idx);
}
}
@@ -104,11 +101,6 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(grad_outputs);
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return alphaCompositeCpuBackward(
grad_outputs, features, alphas, points_idx);
}

View File

@@ -33,11 +33,11 @@ __global__ void weightedSumNormCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * H * W;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -96,11 +96,11 @@ __global__ void weightedSumNormCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * W * H;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size

View File

@@ -73,10 +73,6 @@ torch::Tensor weightedSumNormForward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return weightedSumNormCpuForward(features, alphas, points_idx);
}
}
@@ -104,11 +100,6 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(grad_outputs);
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return weightedSumNormCpuBackward(
grad_outputs, features, alphas, points_idx);
}

View File

@@ -31,11 +31,11 @@ __global__ void weightedSumCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * H * W;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -78,11 +78,11 @@ __global__ void weightedSumCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
const auto batch = blockIdx.x;
const int batch = blockIdx.x;
const int num_pixels = C * H * W;
const auto num_threads = gridDim.y * blockDim.x;
const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
const int num_threads = gridDim.y * blockDim.x;
const int tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each pixel to compute the contribution to the
// gradient for the features and weights

View File

@@ -72,9 +72,6 @@ torch::Tensor weightedSumForward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return weightedSumCpuForward(features, alphas, points_idx);
}
}
@@ -101,11 +98,6 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(grad_outputs);
CHECK_CPU(features);
CHECK_CPU(alphas);
CHECK_CPU(points_idx);
return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
}
}

View File

@@ -8,6 +8,7 @@
// clang-format off
#include "./pulsar/global.h" // Include before <torch/extension.h>.
#include <torch/extension.h>
// clang-format on
#include "./pulsar/pytorch/renderer.h"
#include "./pulsar/pytorch/tensor_util.h"
@@ -105,16 +106,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<
pulsar::pytorch::Renderer,
std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
.def(
py::init<
const uint&,
const uint&,
const uint&,
const bool&,
const bool&,
const float&,
const uint&,
const uint&>())
.def(py::init<
const uint&,
const uint&,
const uint&,
const bool&,
const bool&,
const float&,
const uint&,
const uint&>())
.def(
"__eq__",
[](const pulsar::pytorch::Renderer& a,

View File

@@ -60,8 +60,6 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(verts);
CHECK_CPU(faces);
return FaceAreasNormalsForwardCpu(verts, faces);
}
@@ -82,9 +80,5 @@ at::Tensor FaceAreasNormalsBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(grad_areas);
CHECK_CPU(grad_normals);
CHECK_CPU(verts);
CHECK_CPU(faces);
return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
}

View File

@@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel(
const size_t V,
const size_t D,
const size_t E) {
const auto tid = threadIdx.x;
const int tid = threadIdx.x;
// Reverse the vertex order if backward.
const int v0_idx = backward ? 1 : 0;
const int v1_idx = backward ? 0 : 1;
// Edges are split evenly across the blocks.
for (auto e = blockIdx.x; e < E; e += gridDim.x) {
for (int e = blockIdx.x; e < E; e += gridDim.x) {
// Get indices of vertices which form the edge.
const int64_t v0 = edges[2 * e + v0_idx];
const int64_t v1 = edges[2 * e + v1_idx];
@@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel(
// Split vertex features evenly across threads.
// This implementation will be quite wasteful when D<128 since there will be
// a lot of threads doing nothing.
for (auto d = tid; d < D; d += blockDim.x) {
for (int d = tid; d < D; d += blockDim.x) {
const float val = input[v1 * D + d];
float* address = output + v0 * D + d;
atomicAdd(address, val);

View File

@@ -53,7 +53,5 @@ at::Tensor GatherScatter(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(input);
CHECK_CPU(edges);
return GatherScatterCpu(input, edges, directed, backward);
}

View File

@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
const size_t P,
const size_t F,
const size_t D) {
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
const auto num_threads = blockDim.x * gridDim.x;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
const size_t P,
const size_t F,
const size_t D) {
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
const auto num_threads = blockDim.x * gridDim.x;
const int tid = threadIdx.x + blockIdx.x * blockDim.x;
const int num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;

View File

@@ -57,8 +57,6 @@ at::Tensor InterpFaceAttrsForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(face_attrs);
CHECK_CPU(barycentric_coords);
return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
}
@@ -108,9 +106,6 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(face_attrs);
CHECK_CPU(barycentric_coords);
CHECK_CPU(grad_pix_attrs);
return InterpFaceAttrsBackwardCpu(
pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
}

View File

@@ -44,7 +44,5 @@ inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(boxes1);
CHECK_CPU(boxes2);
return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
}

View File

@@ -74,8 +74,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(p1);
CHECK_CPU(p2);
return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
}
@@ -142,8 +140,6 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(p1);
CHECK_CPU(p2);
return KNearestNeighborBackwardCpu(
p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
}

View File

@@ -58,6 +58,5 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(vol);
return MarchingCubesCpu(vol.contiguous(), isolevel);
}

View File

@@ -88,8 +88,6 @@ at::Tensor PackedToPadded(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(inputs_packed);
CHECK_CPU(first_idxs);
return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
}
@@ -107,7 +105,5 @@ at::Tensor PaddedToPacked(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(inputs_padded);
CHECK_CPU(first_idxs);
return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
}

View File

@@ -174,8 +174,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
// clang-format on
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
auto dists_a = dists.accessor<float, 1>();
@@ -230,10 +230,10 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
at::Tensor grad_as = at::zeros_like(as);
at::Tensor grad_bs = at::zeros_like(bs);
auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
auto grad_as_a = grad_as.accessor<float, H1 == 1 ? 2 : 3>();
auto grad_bs_a = grad_bs.accessor<float, H2 == 1 ? 2 : 3>();
auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
auto grad_dists_a = grad_dists.accessor<float, 1>();

View File

@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
__syncthreads();
// Perform reduction in shared memory.
for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
for (int s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
const float3* tris_f3 = (float3*)tris;
// Parallelize over P * S computations
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // segment index.
@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
const float3* tris_f3 = (float3*)tris;
// Parallelize over P * S computations
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // triangle index.
@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.

View File

@@ -88,10 +88,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(points_first_idx);
CHECK_CPU(tris);
CHECK_CPU(tris_first_idx);
return PointFaceDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}
@@ -147,10 +143,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(tris);
CHECK_CPU(idx_points);
CHECK_CPU(grad_dists);
return PointFaceDistanceBackwardCpu(
points, tris, idx_points, grad_dists, min_triangle_area);
}
@@ -229,10 +221,6 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(points_first_idx);
CHECK_CPU(tris);
CHECK_CPU(tris_first_idx);
return FacePointDistanceForwardCpu(
points, points_first_idx, tris, tris_first_idx, min_triangle_area);
}
@@ -289,10 +277,6 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(tris);
CHECK_CPU(idx_tris);
CHECK_CPU(grad_dists);
return FacePointDistanceBackwardCpu(
points, tris, idx_tris, grad_dists, min_triangle_area);
}
@@ -362,10 +346,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(points_first_idx);
CHECK_CPU(segms);
CHECK_CPU(segms_first_idx);
return PointEdgeDistanceForwardCpu(
points, points_first_idx, segms, segms_first_idx, max_points);
}
@@ -416,10 +396,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(segms);
CHECK_CPU(idx_points);
CHECK_CPU(grad_dists);
return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
}
@@ -488,10 +464,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(points_first_idx);
CHECK_CPU(segms);
CHECK_CPU(segms_first_idx);
return EdgePointDistanceForwardCpu(
points, points_first_idx, segms, segms_first_idx, max_segms);
}
@@ -542,10 +514,6 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(segms);
CHECK_CPU(idx_segms);
CHECK_CPU(grad_dists);
return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
}
@@ -599,8 +567,6 @@ torch::Tensor PointFaceArrayDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(tris);
return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
}
@@ -647,9 +613,6 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(tris);
CHECK_CPU(grad_dists);
return PointFaceArrayDistanceBackwardCpu(
points, tris, grad_dists, min_triangle_area);
}
@@ -698,8 +661,6 @@ torch::Tensor PointEdgeArrayDistanceForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(segms);
return PointEdgeArrayDistanceForwardCpu(points, segms);
}
@@ -742,8 +703,5 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(segms);
CHECK_CPU(grad_dists);
return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
}

View File

@@ -104,12 +104,6 @@ inline void PointsToVolumesForward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points_3d);
CHECK_CPU(points_features);
CHECK_CPU(volume_densities);
CHECK_CPU(volume_features);
CHECK_CPU(grid_sizes);
CHECK_CPU(mask);
PointsToVolumesForwardCpu(
points_3d,
points_features,
@@ -189,14 +183,6 @@ inline void PointsToVolumesBackward(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points_3d);
CHECK_CPU(points_features);
CHECK_CPU(grid_sizes);
CHECK_CPU(mask);
CHECK_CPU(grad_volume_densities);
CHECK_CPU(grad_volume_features);
CHECK_CPU(grad_points_3d);
CHECK_CPU(grad_points_features);
PointsToVolumesBackwardCpu(
points_3d,
points_features,

View File

@@ -15,8 +15,8 @@
#endif
#if defined(_WIN64) || defined(_WIN32)
using uint = unsigned int;
using ushort = unsigned short;
#define uint unsigned int
#define ushort unsigned short
#endif
#include "./logging.h" // <- include before torch/extension.h

View File

@@ -417,7 +417,7 @@ __device__ static float atomicMin(float* address, float val) {
(OUT_PTR), \
(NUM_SELECTED_PTR), \
(NUM_ITEMS), \
(STREAM));
stream = (STREAM));
#define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
HANDLECUDA(cudaMemcpy( \

View File

@@ -357,11 +357,11 @@ void MAX_WS(
//
//
#define END_PARALLEL() \
end_parallel:; \
end_parallel :; \
}
#define END_PARALLEL_NORET() }
#define END_PARALLEL_2D() \
end_parallel:; \
end_parallel :; \
} \
}
#define END_PARALLEL_2D_NORET() \

View File

@@ -70,6 +70,11 @@ struct CamGradInfo {
float3 pixel_dir_y;
};
// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
struct IntWrapper {
int val;
};
} // namespace pulsar
#endif

View File

@@ -149,6 +149,11 @@ IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
return res;
}
IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
IntWrapper res;
res.val = a.val + b.val;
return res;
}
} // namespace pulsar
#endif

View File

@@ -155,8 +155,8 @@ void backward(
stream);
CHECKLAUNCH();
SUM_WS(
self->ids_sorted_d,
self->n_grad_contributions_d,
(IntWrapper*)(self->ids_sorted_d),
(IntWrapper*)(self->n_grad_contributions_d),
static_cast<int>(num_balls),
self->workspace_d,
self->workspace_size,

View File

@@ -52,7 +52,7 @@ HOST void construct(
self->cam.film_width = width;
self->cam.film_height = height;
self->max_num_balls = max_num_balls;
MALLOC(self->result_d, float, width * height * n_channels);
MALLOC(self->result_d, float, width* height* n_channels);
self->cam.orthogonal_projection = orthogonal_projection;
self->cam.right_handed = right_handed_system;
self->cam.background_normalization_depth = background_normalization_depth;
@@ -93,7 +93,7 @@ HOST void construct(
MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
MALLOC(self->region_flags_d, char, max_num_balls);
MALLOC(self->num_selected_d, size_t, 1);
MALLOC(self->forw_info_d, float, width * height * (3 + 2 * n_track));
MALLOC(self->forw_info_d, float, width* height * (3 + 2 * n_track));
MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
MALLOC(self->grad_pos_d, float3, max_num_balls);
MALLOC(self->grad_col_d, float, max_num_balls* n_channels);

View File

@@ -255,7 +255,7 @@ GLOBAL void calc_signature(
* for every iteration through the loading loop every thread could add a
* 'hit' to the buffer.
*/
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE * 2
#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
/**
* The threshold after which the spheres that are in the render buffer
* are rendered and the buffer is flushed.

View File

@@ -6,6 +6,9 @@
* LICENSE file in the root directory of this source tree.
*/
#include "./global.h"
#include "./logging.h"
/**
* A compilation unit to provide warnings about the code and avoid
* repeated messages.

View File

@@ -25,7 +25,7 @@ class BitMask {
// Use all threads in the current block to clear all bits of this BitMask
__device__ void block_clear() {
for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
data[i] = 0;
}
__syncthreads();

View File

@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
const float blur_radius,
float* bboxes, // (4, F)
bool* skip_face) { // (F,)
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = blockDim.x * gridDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
const float sqrt_radius = sqrt(blur_radius);
for (int f = tid; f < F; f += num_threads) {
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
const int P,
float* bboxes, // (4, P)
bool* skip_points) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto num_threads = blockDim.x * gridDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = blockDim.x * gridDim.x;
for (int p = tid; p < P; p += num_threads) {
const float x = points[p * 3 + 0];
const float y = points[p * 3 + 1];
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;
for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int elem_chunk_start_idx = chunk_idx * chunk_size;
@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];
// Have each thread handle a different face within the chunk
for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
const int e_idx = elem_chunk_start_idx + e;
// Check that we are still within the same element of the batch
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
// Now we have processed every elem in the current chunk. We need to
// count the number of elems in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;

View File

@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
float* pix_dists,
float* bary) {
// Simple version: One thread per output pixel
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockDim.x * blockIdx.x + threadIdx.x;
int num_threads = gridDim.x * blockDim.x;
int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Parallelize over each pixel in images of
// size H * W, for each image in the batch of size N.
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
// Convert linear index to 3D index
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
) {
// This can be more than H * W if H or W are not divisible by bin_size.
int num_pixels = N * BH * BW * bin_size * bin_size;
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within

View File

@@ -138,9 +138,6 @@ RasterizeMeshesNaive(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(mesh_to_face_first_idx);
CHECK_CPU(num_faces_per_mesh);
return RasterizeMeshesNaiveCpu(
face_verts,
mesh_to_face_first_idx,
@@ -235,11 +232,6 @@ torch::Tensor RasterizeMeshesBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(pix_to_face);
CHECK_CPU(grad_zbuf);
CHECK_CPU(grad_bary);
CHECK_CPU(grad_dists);
return RasterizeMeshesBackwardCpu(
face_verts,
pix_to_face,
@@ -314,9 +306,6 @@ torch::Tensor RasterizeMeshesCoarse(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(mesh_to_face_first_idx);
CHECK_CPU(num_faces_per_mesh);
return RasterizeMeshesCoarseCpu(
face_verts,
mesh_to_face_first_idx,
@@ -434,8 +423,6 @@ RasterizeMeshesFine(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(face_verts);
CHECK_CPU(bin_faces);
AT_ERROR("NOT IMPLEMENTED");
}
}

View File

@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, H, W, K)
// Simple version: One thread per output pixel
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
const int n = i / (H * W); // Batch index
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
float* pix_dists) { // (N, H, W, K)
// This can be more than H * W if H or W are not divisible by bin_size.
const int num_pixels = N * BH * BW * bin_size * bin_size;
const auto num_threads = gridDim.x * blockDim.x;
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const int num_threads = gridDim.x * blockDim.x;
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
float* grad_points) { // (P, 3)
// Parallelized over each of K points per pixel, for each pixel in images of
// size H * W, for each image in the batch of size N.
auto num_threads = gridDim.x * blockDim.x;
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
int num_threads = gridDim.x * blockDim.x;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = tid; i < N * H * W * K; i += num_threads) {
// const int n = i / (H * W * K); // batch index (not needed).
const int yxk = i % (H * W * K);

View File

@@ -91,10 +91,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(cloud_to_packed_first_idx);
CHECK_CPU(num_points_per_cloud);
CHECK_CPU(radius);
return RasterizePointsNaiveCpu(
points,
cloud_to_packed_first_idx,
@@ -170,10 +166,6 @@ torch::Tensor RasterizePointsCoarse(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(cloud_to_packed_first_idx);
CHECK_CPU(num_points_per_cloud);
CHECK_CPU(radius);
return RasterizePointsCoarseCpu(
points,
cloud_to_packed_first_idx,
@@ -240,8 +232,6 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(bin_points);
AT_ERROR("NOT IMPLEMENTED");
}
}
@@ -294,10 +284,6 @@ torch::Tensor RasterizePointsBackward(
AT_ERROR("Not compiled with GPU support");
#endif
} else {
CHECK_CPU(points);
CHECK_CPU(idxs);
CHECK_CPU(grad_zbuf);
CHECK_CPU(grad_dists);
return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
}
}

View File

@@ -107,8 +107,7 @@ at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points, // (N, P, 3)
const at::Tensor& lengths, // (N,)
const at::Tensor& K, // (N,)
const at::Tensor& start_idxs,
const int64_t max_K_known = -1) {
const at::Tensor& start_idxs) {
// Check inputs are on the same device
at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -130,12 +129,7 @@ at::Tensor FarthestPointSamplingCuda(
const int64_t N = points.size(0);
const int64_t P = points.size(1);
int64_t max_K;
if (max_K_known > 0) {
max_K = max_K_known;
} else {
max_K = at::max(K).item<int64_t>();
}
const int64_t max_K = at::max(K).item<int64_t>();
// Initialize the output tensor with the sampled indices
auto idxs = at::full({N, max_K}, -1, lengths.options());

View File

@@ -43,8 +43,7 @@ at::Tensor FarthestPointSamplingCuda(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs,
const int64_t max_K_known = -1);
const at::Tensor& start_idxs);
at::Tensor FarthestPointSamplingCpu(
const at::Tensor& points,
@@ -57,23 +56,17 @@ at::Tensor FarthestPointSampling(
const at::Tensor& points,
const at::Tensor& lengths,
const at::Tensor& K,
const at::Tensor& start_idxs,
const int64_t max_K_known = -1) {
const at::Tensor& start_idxs) {
if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
#ifdef WITH_CUDA
CHECK_CUDA(points);
CHECK_CUDA(lengths);
CHECK_CUDA(K);
CHECK_CUDA(start_idxs);
return FarthestPointSamplingCuda(
points, lengths, K, start_idxs, max_K_known);
return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
#else
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(points);
CHECK_CPU(lengths);
CHECK_CPU(K);
CHECK_CPU(start_idxs);
return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
}

View File

@@ -71,8 +71,6 @@ inline void SamplePdf(
AT_ERROR("Not compiled with GPU support.");
#endif
}
CHECK_CPU(weights);
CHECK_CPU(outputs);
CHECK_CONTIGUOUS(outputs);
SamplePdfCpu(bins, weights, outputs, eps);
}

View File

@@ -99,7 +99,8 @@ namespace {
// and increment it via template recursion until it is equal to the run-time
// argument N.
template <
template <typename, int64_t> class Kernel,
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -123,7 +124,8 @@ struct DispatchKernelHelper1D {
// 1D dispatch: Specialization when curN == maxN
// We need this base case to avoid infinite template recursion.
template <
template <typename, int64_t> class Kernel,
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -143,7 +145,8 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
// the run-time values of N and M, at which point we dispatch to the run
// method of the kernel.
template <
template <typename, int64_t, int64_t> class Kernel,
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -200,7 +203,8 @@ struct DispatchKernelHelper2D {
// 2D dispatch, specialization for curN == maxN
template <
template <typename, int64_t, int64_t> class Kernel,
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -239,7 +243,8 @@ struct DispatchKernelHelper2D<
// 2D dispatch, specialization for curM == maxM
template <
template <typename, int64_t, int64_t> class Kernel,
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -278,7 +283,8 @@ struct DispatchKernelHelper2D<
// 2D dispatch, specialization for curN == maxN, curM == maxM
template <
template <typename, int64_t, int64_t> class Kernel,
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -307,7 +313,8 @@ struct DispatchKernelHelper2D<
// This is the function we expect users to call to dispatch to 1D functions
template <
template <typename, int64_t> class Kernel,
template <typename, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,
@@ -323,7 +330,8 @@ void DispatchKernel1D(const int64_t N, Args... args) {
// This is the function we expect users to call to dispatch to 2D functions
template <
template <typename, int64_t, int64_t> class Kernel,
template <typename, int64_t, int64_t>
class Kernel,
typename T,
int64_t minN,
int64_t maxN,

View File

@@ -15,7 +15,3 @@
#define CHECK_CONTIGUOUS_CUDA(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_CPU(x) \
TORCH_CHECK( \
x.device().type() == torch::kCPU, \
"Cannot use CPU implementation: " #x " not on CPU.")

View File

@@ -755,7 +755,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
if pick_sequences:
old_len = len(eval_batches)
eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
logger.warning(
logger.warn(
f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
)
@@ -763,7 +763,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
old_len = len(eval_batches)
exclude_sequences = set(self.exclude_sequences)
eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
logger.warning(
logger.warn(
f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
)

View File

@@ -21,6 +21,8 @@ import logging
import warnings
from collections.abc import Mapping
from dataclasses import dataclass, field
from distutils.version import LooseVersion
from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type
import torch
@@ -220,8 +222,7 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
+ "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
)
# We assume PyTorch 1.11 and newer.
interpolate_has_antialias = True
interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
if antialias and not interpolate_has_antialias:
warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")

View File

@@ -304,11 +304,11 @@ def _show_predictions(
assert isinstance(preds, list)
pred_all = []
# Randomly choose a subset of the rendered images, sort by order in the sequence
# Randomly choose a subset of the rendered images, sort by ordr in the sequence
n_samples = min(n_samples, len(preds))
pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
for predi in pred_idx:
# Make the concatenation for the same camera vertically
# Make the concatentation for the same camera vertically
pred_all.append(
torch.cat(
[
@@ -359,7 +359,7 @@ def _generate_prediction_videos(
vws = {}
for k in predicted_keys:
if k not in preds[0]:
logger.warning(f"Cannot generate video for prediction key '{k}'")
logger.warn(f"Cannot generate video for prediction key '{k}'")
continue
cache_dir = (
None

View File

@@ -23,13 +23,11 @@ class _ball_query(Function):
"""
@staticmethod
def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube):
def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
"""
Arguments defintions the same as in the ball_query function
"""
idx, dists = _C.ball_query(
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
)
idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
ctx.mark_non_differentiable(idx)
return dists, idx
@@ -51,7 +49,7 @@ class _ball_query(Function):
grad_p1, grad_p2 = _C.knn_points_backward(
p1, p2, lengths1, lengths2, idx, 2, grad_dists
)
return grad_p1, grad_p2, None, None, None, None, None
return grad_p1, grad_p2, None, None, None, None
def ball_query(
@@ -62,7 +60,6 @@ def ball_query(
K: int = 500,
radius: float = 0.2,
return_nn: bool = True,
skip_points_outside_cube: bool = False,
):
"""
Ball Query is an alternative to KNN. It can be
@@ -101,9 +98,6 @@ def ball_query(
within the radius
radius: the radius around each point within which the neighbors need to be located
return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
skip_points_outside_cube: If set to True, reduce multiplications of float values
by not explicitly calculating distances to points that fall outside the
D-cube with side length (2*radius) centered at each point in p1.
Returns:
dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -140,9 +134,7 @@ def ball_query(
if lengths2 is None:
lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)
dists, idx = _ball_query.apply(
p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
)
dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
# Gather the neighbors if needed
points_nn = masked_gather(p2, idx) if return_nn else None

View File

@@ -47,7 +47,8 @@ def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
# i.e. A[i, j] = 1 if (i,j) is an edge, or
# A[e0, e1] = 1 & A[e1, e0] = 1
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
A = torch.sparse_coo_tensor(idx, ones, (V, V), dtype=torch.float32)
# pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
A = torch.sparse.FloatTensor(idx, ones, (V, V))
# the sum of i-th row of A gives the degree of the i-th vertex
deg = torch.sparse.sum(A, dim=1).to_dense()
@@ -61,13 +62,15 @@ def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
# pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
val = torch.cat([deg0, deg1])
L = torch.sparse_coo_tensor(idx, val, (V, V), dtype=torch.float32)
# pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
L = torch.sparse.FloatTensor(idx, val, (V, V))
# Then we add the diagonal values L[i, i] = -1.
idx = torch.arange(V, device=verts.device)
idx = torch.stack([idx, idx], dim=0)
ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
L -= torch.sparse_coo_tensor(idx, ones, (V, V), dtype=torch.float32)
# pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
L -= torch.sparse.FloatTensor(idx, ones, (V, V))
return L
@@ -123,7 +126,8 @@ def cot_laplacian(
ii = faces[:, [1, 2, 0]]
jj = faces[:, [2, 0, 1]]
idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V), dtype=torch.float32)
# pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
# Make it symmetric; this means we are also setting
# L[v2, v1] = cota
@@ -163,7 +167,7 @@ def norm_laplacian(
v0, v1 = edge_verts[:, 0], edge_verts[:, 1]
# Side lengths of each edge, of shape (E,)
w01 = torch.reciprocal((v0 - v1).norm(dim=1) + eps)
w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
# Construct a sparse matrix by basically doing:
# L[v0, v1] = w01
@@ -171,7 +175,8 @@ def norm_laplacian(
e01 = edges.t() # (2, E)
V = verts.shape[0]
L = torch.sparse_coo_tensor(e01, w01, (V, V), dtype=torch.float32)
# pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
L = torch.sparse.FloatTensor(e01, w01, (V, V))
L = L + L.t()
return L

View File

@@ -55,7 +55,6 @@ def sample_farthest_points(
N, P, D = points.shape
device = points.device
constant_length = lengths is None
# Validate inputs
if lengths is None:
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
@@ -66,9 +65,7 @@ def sample_farthest_points(
raise ValueError("A value in lengths was too large.")
# TODO: support providing K as a ratio of the total number of points instead of as an int
max_K = -1
if isinstance(K, int):
max_K = K
K = torch.full((N,), K, dtype=torch.int64, device=device)
elif isinstance(K, list):
K = torch.tensor(K, dtype=torch.int64, device=device)
@@ -85,19 +82,15 @@ def sample_farthest_points(
K = K.to(torch.int64)
# Generate the starting indices for sampling
start_idxs = torch.zeros_like(lengths)
if random_start_point:
if constant_length:
start_idxs = torch.randint(high=P, size=(N,), device=device)
else:
start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
torch.int64
)
else:
start_idxs = torch.zeros_like(lengths)
for n in range(N):
# pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
with torch.no_grad():
# pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
sampled_points = masked_gather(points, idx)
return sampled_points, idx

View File

@@ -160,10 +160,9 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
indices = q_abs.argmax(dim=-1, keepdim=True)
expand_dims = list(batch_dim) + [1, 4]
gather_indices = indices.unsqueeze(-1).expand(expand_dims)
out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
out = quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
return standardize_quaternion(out)
@@ -294,11 +293,10 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tenso
tait_bryan = i0 != i2
if tait_bryan:
central_angle = torch.asin(
torch.clamp(matrix[..., i0, i2], -1.0, 1.0)
* (-1.0 if i0 - i2 in [-1, 2] else 1.0)
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
)
else:
central_angle = torch.acos(torch.clamp(matrix[..., i0, i0], -1.0, 1.0))
central_angle = torch.acos(matrix[..., i0, i0])
o = (
_angle_from_tan(

View File

@@ -75,21 +75,6 @@ def get_extensions():
]
if os.name != "nt":
nvcc_args.append("-std=c++17")
# CUDA 13.0+ compatibility flags for pulsar.
# Starting with CUDA 13, __global__ function visibility changed.
# See: https://developer.nvidia.com/blog/
# cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
cuda_version = torch.version.cuda
if cuda_version is not None:
major = int(cuda_version.split(".")[0])
if major >= 13:
nvcc_args.extend(
[
"--device-entity-has-hidden-visibility=false",
"-static-global-template-stub=false",
]
)
if cub_home is None:
prefix = os.environ.get("CONDA_PREFIX", None)
if prefix is not None and os.path.isdir(prefix + "/include/cub"):
@@ -149,7 +134,7 @@ if os.getenv("PYTORCH3D_NO_NINJA", "0") == "1":
class BuildExtension(torch.utils.cpp_extension.BuildExtension):
def __init__(self, *args, **kwargs):
super().__init__(*args, use_ninja=False, **kwargs)
super().__init__(use_ninja=False, *args, **kwargs)
else:
BuildExtension = torch.utils.cpp_extension.BuildExtension

View File

@@ -1,56 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from itertools import product
import torch
from fvcore.common.benchmark import benchmark
from pytorch3d.ops.ball_query import ball_query
def ball_query_square(
N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
):
device = torch.device(device)
pts1 = torch.rand(N, P1, D, device=device)
pts2 = torch.rand(N, P2, D, device=device)
torch.cuda.synchronize()
def output():
ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True)
torch.cuda.synchronize()
return output
def bm_ball_query() -> None:
backends = ["cpu", "cuda:0"]
kwargs_list = []
Ns = [32]
P1s = [256]
P2s = [2**p for p in range(9, 20, 2)]
Ds = [3, 10]
Ks = [500]
Rs = [0.01, 0.1]
test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
for case in test_cases:
N, P1, P2, D, K, R, b = case
kwargs_list.append(
{"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
)
benchmark(
ball_query_square,
"BALLQUERY_SQUARE",
kwargs_list,
num_iters=30,
warmup_iters=1,
)
if __name__ == "__main__":
bm_ball_query()

View File

@@ -31,13 +31,6 @@ def skip_opengl_requested() -> bool:
usesOpengl = unittest.skipIf(skip_opengl_requested(), "uses opengl")
def have_multiple_gpus() -> bool:
return torch.cuda.device_count() > 1
needs_multigpu = unittest.skipIf(not have_multiple_gpus(), "needs multiple GPUs")
def get_tests_dir() -> Path:
"""
Returns Path for the directory containing this file.

View File

@@ -15,7 +15,7 @@ from tests.common_testing import get_pytorch3d_dir
# This file groups together tests which look at the code without running it.
class TestBuild(unittest.TestCase):
def _test_no_import_cycles(self):
def test_no_import_cycles(self):
# Check each module of pytorch3d imports cleanly,
# which may fail if there are import cycles.

View File

@@ -78,7 +78,7 @@ class TestBuild(unittest.TestCase):
self.assertListEqual(sorted(listed_in_json), notes_on_disk)
def _test_no_import_cycles(self):
def test_no_import_cycles(self):
# Check each module of pytorch3d imports cleanly,
# which may fail if there are import cycles.

View File

@@ -72,7 +72,6 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
factors = [Ns, Ds, P1s, P2s, Ks, norms]
for N, D, P1, P2, K, norm in product(*factors):
for version in versions:
torch.manual_seed(2)
if version == 3 and K > 4:
continue
x = torch.randn(N, P1, D, device=device, requires_grad=True)

View File

@@ -703,6 +703,80 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
self.assertEqual(cuda_device, cloud.device)
self.assertIsNot(cloud, converted_cloud)
def test_to_list(self):
cloud = self.init_cloud(5, 100, 10)
device = torch.device("cuda:1")
new_cloud = cloud.to(device)
self.assertTrue(new_cloud.device == device)
self.assertTrue(cloud.device == torch.device("cuda:0"))
for attrib in [
"points_padded",
"points_packed",
"normals_padded",
"normals_packed",
"features_padded",
"features_packed",
"num_points_per_cloud",
"cloud_to_packed_first_idx",
"padded_to_packed_idx",
]:
self.assertClose(
getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
)
for i in range(len(cloud)):
self.assertClose(
cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
)
self.assertClose(
cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
)
self.assertClose(
cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
)
self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
self.assertTrue(cloud.equisized == new_cloud.equisized)
self.assertTrue(cloud._N == new_cloud._N)
self.assertTrue(cloud._P == new_cloud._P)
self.assertTrue(cloud._C == new_cloud._C)
def test_to_tensor(self):
cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
device = torch.device("cuda:1")
new_cloud = cloud.to(device)
self.assertTrue(new_cloud.device == device)
self.assertTrue(cloud.device == torch.device("cuda:0"))
for attrib in [
"points_padded",
"points_packed",
"normals_padded",
"normals_packed",
"features_padded",
"features_packed",
"num_points_per_cloud",
"cloud_to_packed_first_idx",
"padded_to_packed_idx",
]:
self.assertClose(
getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
)
for i in range(len(cloud)):
self.assertClose(
cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
)
self.assertClose(
cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
)
self.assertClose(
cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
)
self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
self.assertTrue(cloud.equisized == new_cloud.equisized)
self.assertTrue(cloud._N == new_cloud._N)
self.assertTrue(cloud._P == new_cloud._P)
self.assertTrue(cloud._C == new_cloud._C)
def test_split(self):
clouds = self.init_cloud(5, 100, 10)
split_sizes = [2, 3]

View File

@@ -1,166 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import numpy as np
import torch
from pytorch3d.structures.pointclouds import Pointclouds
from .common_testing import needs_multigpu, TestCaseMixin
class TestPointclouds(TestCaseMixin, unittest.TestCase):
def setUp(self) -> None:
np.random.seed(42)
torch.manual_seed(42)
@staticmethod
def init_cloud(
num_clouds: int = 3,
max_points: int = 100,
channels: int = 4,
lists_to_tensors: bool = False,
with_normals: bool = True,
with_features: bool = True,
min_points: int = 0,
requires_grad: bool = False,
):
"""
Function to generate a Pointclouds object of N meshes with
random number of points.
Args:
num_clouds: Number of clouds to generate.
channels: Number of features.
max_points: Max number of points per cloud.
lists_to_tensors: Determines whether the generated clouds should be
constructed from lists (=False) or
tensors (=True) of points/normals/features.
with_normals: bool whether to include normals
with_features: bool whether to include features
min_points: Min number of points per cloud
Returns:
Pointclouds object.
"""
device = torch.device("cuda:0")
p = torch.randint(low=min_points, high=max_points, size=(num_clouds,))
if lists_to_tensors:
p.fill_(p[0])
points_list = [
torch.rand(
(i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
)
for i in p
]
normals_list, features_list = None, None
if with_normals:
normals_list = [
torch.rand(
(i, 3),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
]
if with_features:
features_list = [
torch.rand(
(i, channels),
device=device,
dtype=torch.float32,
requires_grad=requires_grad,
)
for i in p
]
if lists_to_tensors:
points_list = torch.stack(points_list)
if with_normals:
normals_list = torch.stack(normals_list)
if with_features:
features_list = torch.stack(features_list)
return Pointclouds(points_list, normals=normals_list, features=features_list)
@needs_multigpu
def test_to_list(self):
cloud = self.init_cloud(5, 100, 10)
device = torch.device("cuda:1")
new_cloud = cloud.to(device)
self.assertTrue(new_cloud.device == device)
self.assertTrue(cloud.device == torch.device("cuda:0"))
for attrib in [
"points_padded",
"points_packed",
"normals_padded",
"normals_packed",
"features_padded",
"features_packed",
"num_points_per_cloud",
"cloud_to_packed_first_idx",
"padded_to_packed_idx",
]:
self.assertClose(
getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
)
for i in range(len(cloud)):
self.assertClose(
cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
)
self.assertClose(
cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
)
self.assertClose(
cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
)
self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
self.assertTrue(cloud.equisized == new_cloud.equisized)
self.assertTrue(cloud._N == new_cloud._N)
self.assertTrue(cloud._P == new_cloud._P)
self.assertTrue(cloud._C == new_cloud._C)
@needs_multigpu
def test_to_tensor(self):
cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
device = torch.device("cuda:1")
new_cloud = cloud.to(device)
self.assertTrue(new_cloud.device == device)
self.assertTrue(cloud.device == torch.device("cuda:0"))
for attrib in [
"points_padded",
"points_packed",
"normals_padded",
"normals_packed",
"features_padded",
"features_packed",
"num_points_per_cloud",
"cloud_to_packed_first_idx",
"padded_to_packed_idx",
]:
self.assertClose(
getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
)
for i in range(len(cloud)):
self.assertClose(
cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
)
self.assertClose(
cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
)
self.assertClose(
cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
)
self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
self.assertTrue(cloud.equisized == new_cloud.equisized)
self.assertTrue(cloud._N == new_cloud._N)
self.assertTrue(cloud._P == new_cloud._P)
self.assertTrue(cloud._C == new_cloud._C)

View File

@@ -165,7 +165,7 @@ class TestICP(TestCaseMixin, unittest.TestCase):
a set of randomly-sized Pointclouds and on their padded versions.
"""
torch.manual_seed(14)
torch.manual_seed(4)
device = torch.device("cuda:0")
for estimate_scale in (True, False):

View File

@@ -29,7 +29,7 @@ from pytorch3d.renderer.opengl import MeshRasterizerOpenGL
from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.utils.ico_sphere import ico_sphere
from .common_testing import needs_multigpu, TestCaseMixin, usesOpengl
from .common_testing import TestCaseMixin, usesOpengl
# Set the number of GPUS you want to test with
@@ -116,7 +116,6 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
output_images = renderer(mesh)
self.assertEqual(output_images.device, device2)
@needs_multigpu
def test_mesh_renderer_to(self):
self._mesh_renderer_to(MeshRasterizer, SoftPhongShader)
@@ -174,7 +173,6 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
for _ in range(100):
model(verts, texs)
@needs_multigpu
def test_render_meshes(self):
self._render_meshes(MeshRasterizer, HardGouraudShader)

View File

@@ -63,6 +63,9 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
self.assertEqual(example_gpu.device.type, "cuda")
self.assertIsNotNone(example_gpu.device.index)
example_gpu1 = example.cuda(1)
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
def test_clone(self):
# Check clone method
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))

View File

@@ -8,6 +8,7 @@
import itertools
import math
import unittest
from distutils.version import LooseVersion
from typing import Optional, Union
import numpy as np
@@ -270,6 +271,7 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
)
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(axis_angle_to_matrix)
torch.jit.script(axis_angle_to_quaternion)

View File

@@ -7,6 +7,7 @@
import math
import unittest
from distutils.version import LooseVersion
import numpy as np
import torch
@@ -254,6 +255,7 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
# all grad values have to be finite
self.assertTrue(torch.isfinite(r.grad).all())
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
def test_scriptable(self):
torch.jit.script(so3_exp_map)
torch.jit.script(so3_log_map)