Mirror of https://github.com/facebookresearch/pytorch3d.git
Synced 2025-12-20 22:30:35 +08:00

Compare commits: bottler/un… → main (25 commits)

f5f6b78e70 33824be3cb 2d4d345b6f 45df20e9e2 fc6a6b8951
7711bf34a8 d098beb7a7 dd068703d1 50f8efa1cb 5043d15361
e3d3a67a89 e55ea90609 3aee2a6005 c5ea8fa49e 3ff6c5ab85
267bd8ef87 177eec6378 71db7a0ea2 6020323d94 182e845c19
f315ac131b fc08621879 3f327a516b 366eff21d9 0a59450f0e
@@ -10,7 +10,7 @@
 DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
 DIR=$(dirname "${DIR}")

-if [[ -f "${DIR}/TARGETS" ]]
+if [[ -f "${DIR}/BUCK" ]]
 then
   pyfmt "${DIR}"
 else
@@ -6,4 +6,4 @@

 # pyre-unsafe

-__version__ = "0.7.8"
+__version__ = "0.7.9"
@@ -32,7 +32,9 @@ __global__ void BallQueryKernel(
     at::PackedTensorAccessor64<int64_t, 3, at::RestrictPtrTraits> idxs,
     at::PackedTensorAccessor64<scalar_t, 3, at::RestrictPtrTraits> dists,
     const int64_t K,
-    const float radius2) {
+    const float radius,
+    const float radius2,
+    const bool skip_points_outside_cube) {
   const int64_t N = p1.size(0);
   const int64_t chunks_per_cloud = (1 + (p1.size(1) - 1) / blockDim.x);
   const int64_t chunks_to_do = N * chunks_per_cloud;
@@ -51,7 +53,19 @@ __global__ void BallQueryKernel(
     // Iterate over points in p2 until desired count is reached or
     // all points have been considered
     for (int64_t j = 0, count = 0; j < lengths2[n] && count < K; ++j) {
-      // Calculate the distance between the points
+      if (skip_points_outside_cube) {
+        bool is_within_radius = true;
+        // Filter when any one coordinate is already outside the radius
+        for (int d = 0; is_within_radius && d < D; ++d) {
+          scalar_t abs_diff = fabs(p1[n][i][d] - p2[n][j][d]);
+          is_within_radius = (abs_diff <= radius);
+        }
+        if (!is_within_radius) {
+          continue;
+        }
+      }
+
+      // Else, calculate the distance between the points and compare
       scalar_t dist2 = 0.0;
       for (int d = 0; d < D; ++d) {
         scalar_t diff = p1[n][i][d] - p2[n][j][d];
@@ -77,7 +91,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
     const at::Tensor& lengths1, // (N,)
     const at::Tensor& lengths2, // (N,)
     int K,
-    float radius) {
+    float radius,
+    bool skip_points_outside_cube) {
   // Check inputs are on the same device
   at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
       lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
@@ -120,7 +135,9 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
         idxs.packed_accessor64<int64_t, 3, at::RestrictPtrTraits>(),
         dists.packed_accessor64<float, 3, at::RestrictPtrTraits>(),
         K_64,
-        radius2);
+        radius,
+        radius2,
+        skip_points_outside_cube);
   }));

   AT_CUDA_CHECK(cudaGetLastError());

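Note: the hunks above add an optional "cube" pre-filter to the ball-query kernels. A minimal Python sketch of the same idea, for illustration only (the function name and looping structure here are hypothetical, not part of the diff):

```python
import torch

def ball_query_naive(p1, p2, K, radius, skip_points_outside_cube=False):
    """Collect up to K neighbor indices of each row of p1 within `radius`.

    p1: (P1, D), p2: (P2, D). Returns a list of index lists.
    """
    neighbors_per_point = []
    for x in p1:
        neighbors = []
        for j in range(p2.shape[0]):
            if len(neighbors) == K:
                break
            diff = x - p2[j]
            if skip_points_outside_cube and diff.abs().max() > radius:
                # Outside the cube of side 2*radius => certainly outside the
                # sphere, so skip the full squared-distance computation.
                continue
            if (diff * diff).sum() <= radius * radius:
                neighbors.append(j)
        neighbors_per_point.append(neighbors)
    return neighbors_per_point
```

The per-coordinate test can only discard points the sphere test would also discard, so results are unchanged; far-away points are rejected with comparisons instead of D multiply-accumulate steps.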
@@ -25,6 +25,9 @@
 //     within the radius
 // radius: the radius around each point within which the neighbors need to be
 //     located
+// skip_points_outside_cube: If true, reduce multiplications of float values
+//     by not explicitly calculating distances to points that fall outside the
+//     D-cube with side length (2*radius) centered at each point in p1.
 //
 // Returns:
 //     p1_neighbor_idx: LongTensor of shape (N, P1, K), where
@@ -46,7 +49,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const int K,
-    const float radius);
+    const float radius,
+    const bool skip_points_outside_cube);

 // CUDA implementation
 std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
@@ -55,7 +59,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCuda(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     const int K,
-    const float radius);
+    const float radius,
+    const bool skip_points_outside_cube);

 // Implementation which is exposed
 // Note: the backward pass reuses the KNearestNeighborBackward kernel
@@ -65,7 +70,8 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     int K,
-    float radius) {
+    float radius,
+    bool skip_points_outside_cube) {
   if (p1.is_cuda() || p2.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(p1);
@@ -76,16 +82,20 @@ inline std::tuple<at::Tensor, at::Tensor> BallQuery(
         lengths1.contiguous(),
         lengths2.contiguous(),
         K,
-        radius);
+        radius,
+        skip_points_outside_cube);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(p1);
+  CHECK_CPU(p2);
   return BallQueryCpu(
       p1.contiguous(),
       p2.contiguous(),
       lengths1.contiguous(),
       lengths2.contiguous(),
       K,
-      radius);
+      radius,
+      skip_points_outside_cube);
 }

@@ -6,6 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */

+#include <math.h>
 #include <torch/extension.h>
 #include <tuple>

@@ -15,7 +16,8 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
     const at::Tensor& lengths1,
     const at::Tensor& lengths2,
     int K,
-    float radius) {
+    float radius,
+    bool skip_points_outside_cube) {
   const int N = p1.size(0);
   const int P1 = p1.size(1);
   const int D = p1.size(2);
@@ -37,6 +39,16 @@ std::tuple<at::Tensor, at::Tensor> BallQueryCpu(
     const int64_t length2 = lengths2_a[n];
     for (int64_t i = 0; i < length1; ++i) {
       for (int64_t j = 0, count = 0; j < length2 && count < K; ++j) {
+        if (skip_points_outside_cube) {
+          bool is_within_radius = true;
+          for (int d = 0; is_within_radius && d < D; ++d) {
+            float abs_diff = fabs(p1_a[n][i][d] - p2_a[n][j][d]);
+            is_within_radius = (abs_diff <= radius);
+          }
+          if (!is_within_radius) {
+            continue;
+          }
+        }
         float dist2 = 0;
         for (int d = 0; d < D; ++d) {
           float diff = p1_a[n][i][d] - p2_a[n][j][d];

@@ -98,6 +98,11 @@ at::Tensor SigmoidAlphaBlendBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(distances);
+  CHECK_CPU(pix_to_face);
+  CHECK_CPU(alphas);
+  CHECK_CPU(grad_alphas);
+
   return SigmoidAlphaBlendBackwardCpu(
       grad_alphas, alphas, distances, pix_to_face, sigma);
 }

@@ -74,6 +74,9 @@ torch::Tensor alphaCompositeForward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
     return alphaCompositeCpuForward(features, alphas, points_idx);
   }
 }
@@ -101,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> alphaCompositeBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(grad_outputs);
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
+
     return alphaCompositeCpuBackward(
         grad_outputs, features, alphas, points_idx);
   }

@@ -73,6 +73,10 @@ torch::Tensor weightedSumNormForward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
+
     return weightedSumNormCpuForward(features, alphas, points_idx);
   }
 }
@@ -100,6 +104,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumNormBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(grad_outputs);
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
+
     return weightedSumNormCpuBackward(
         grad_outputs, features, alphas, points_idx);
   }

@@ -72,6 +72,9 @@ torch::Tensor weightedSumForward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
     return weightedSumCpuForward(features, alphas, points_idx);
   }
 }
@@ -98,6 +101,11 @@ std::tuple<torch::Tensor, torch::Tensor> weightedSumBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(grad_outputs);
+    CHECK_CPU(features);
+    CHECK_CPU(alphas);
+    CHECK_CPU(points_idx);
+
     return weightedSumCpuBackward(grad_outputs, features, alphas, points_idx);
   }
 }

@@ -8,7 +8,6 @@

 // clang-format off
 #include "./pulsar/global.h" // Include before <torch/extension.h>.
-#include <torch/extension.h>
 // clang-format on
 #include "./pulsar/pytorch/renderer.h"
 #include "./pulsar/pytorch/tensor_util.h"
@@ -106,15 +105,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   py::class_<
       pulsar::pytorch::Renderer,
       std::shared_ptr<pulsar::pytorch::Renderer>>(m, "PulsarRenderer")
-      .def(py::init<
-           const uint&,
-           const uint&,
-           const uint&,
-           const bool&,
-           const bool&,
-           const float&,
-           const uint&,
-           const uint&>())
+      .def(
+          py::init<
+              const uint&,
+              const uint&,
+              const uint&,
+              const bool&,
+              const bool&,
+              const float&,
+              const uint&,
+              const uint&>())
       .def(
           "__eq__",
           [](const pulsar::pytorch::Renderer& a,

@@ -60,6 +60,8 @@ std::tuple<at::Tensor, at::Tensor> FaceAreasNormalsForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(verts);
+  CHECK_CPU(faces);
   return FaceAreasNormalsForwardCpu(verts, faces);
 }

@@ -80,5 +82,9 @@ at::Tensor FaceAreasNormalsBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(grad_areas);
+  CHECK_CPU(grad_normals);
+  CHECK_CPU(verts);
+  CHECK_CPU(faces);
   return FaceAreasNormalsBackwardCpu(grad_areas, grad_normals, verts, faces);
 }

@@ -53,5 +53,7 @@ at::Tensor GatherScatter(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(input);
+  CHECK_CPU(edges);
   return GatherScatterCpu(input, edges, directed, backward);
 }

@@ -57,6 +57,8 @@ at::Tensor InterpFaceAttrsForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(face_attrs);
+  CHECK_CPU(barycentric_coords);
   return InterpFaceAttrsForwardCpu(pix_to_face, barycentric_coords, face_attrs);
 }

@@ -106,6 +108,9 @@ std::tuple<at::Tensor, at::Tensor> InterpFaceAttrsBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(face_attrs);
+  CHECK_CPU(barycentric_coords);
+  CHECK_CPU(grad_pix_attrs);
   return InterpFaceAttrsBackwardCpu(
       pix_to_face, barycentric_coords, face_attrs, grad_pix_attrs);
 }

@@ -44,5 +44,7 @@ inline std::tuple<at::Tensor, at::Tensor> IoUBox3D(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(boxes1);
+  CHECK_CPU(boxes2);
   return IoUBox3DCpu(boxes1.contiguous(), boxes2.contiguous());
 }

@@ -74,6 +74,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(p1);
+  CHECK_CPU(p2);
   return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
 }

@@ -140,6 +142,8 @@ std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(p1);
+  CHECK_CPU(p2);
   return KNearestNeighborBackwardCpu(
       p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
 }

@@ -58,5 +58,6 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubes(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(vol);
   return MarchingCubesCpu(vol.contiguous(), isolevel);
 }

@@ -88,6 +88,8 @@ at::Tensor PackedToPadded(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(inputs_packed);
+  CHECK_CPU(first_idxs);
   return PackedToPaddedCpu(inputs_packed, first_idxs, max_size);
 }

@@ -105,5 +107,7 @@ at::Tensor PaddedToPacked(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(inputs_padded);
+  CHECK_CPU(first_idxs);
   return PaddedToPackedCpu(inputs_padded, first_idxs, num_inputs);
 }

@@ -174,8 +174,8 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceForwardCpu(
   at::Tensor idxs = at::zeros({A_N,}, as_first_idx.options());
   // clang-format on

-  auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
-  auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
+  auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
+  auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
   auto as_first_idx_a = as_first_idx.accessor<int64_t, 1>();
   auto bs_first_idx_a = bs_first_idx.accessor<int64_t, 1>();
   auto dists_a = dists.accessor<float, 1>();
@@ -230,10 +230,10 @@ std::tuple<at::Tensor, at::Tensor> HullHullDistanceBackwardCpu(
   at::Tensor grad_as = at::zeros_like(as);
   at::Tensor grad_bs = at::zeros_like(bs);

-  auto as_a = as.accessor < float, H1 == 1 ? 2 : 3 > ();
-  auto bs_a = bs.accessor < float, H2 == 1 ? 2 : 3 > ();
-  auto grad_as_a = grad_as.accessor < float, H1 == 1 ? 2 : 3 > ();
-  auto grad_bs_a = grad_bs.accessor < float, H2 == 1 ? 2 : 3 > ();
+  auto as_a = as.accessor<float, H1 == 1 ? 2 : 3>();
+  auto bs_a = bs.accessor<float, H2 == 1 ? 2 : 3>();
+  auto grad_as_a = grad_as.accessor<float, H1 == 1 ? 2 : 3>();
+  auto grad_bs_a = grad_bs.accessor<float, H2 == 1 ? 2 : 3>();
   auto idx_bs_a = idx_bs.accessor<int64_t, 1>();
   auto grad_dists_a = grad_dists.accessor<float, 1>();

@@ -88,6 +88,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(points_first_idx);
+  CHECK_CPU(tris);
+  CHECK_CPU(tris_first_idx);
   return PointFaceDistanceForwardCpu(
       points, points_first_idx, tris, tris_first_idx, min_triangle_area);
 }
@@ -143,6 +147,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(tris);
+  CHECK_CPU(idx_points);
+  CHECK_CPU(grad_dists);
   return PointFaceDistanceBackwardCpu(
       points, tris, idx_points, grad_dists, min_triangle_area);
 }
@@ -221,6 +229,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(points_first_idx);
+  CHECK_CPU(tris);
+  CHECK_CPU(tris_first_idx);
   return FacePointDistanceForwardCpu(
       points, points_first_idx, tris, tris_first_idx, min_triangle_area);
 }
@@ -277,6 +289,10 @@ std::tuple<torch::Tensor, torch::Tensor> FacePointDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(tris);
+  CHECK_CPU(idx_tris);
+  CHECK_CPU(grad_dists);
   return FacePointDistanceBackwardCpu(
       points, tris, idx_tris, grad_dists, min_triangle_area);
 }
@@ -346,6 +362,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(points_first_idx);
+  CHECK_CPU(segms);
+  CHECK_CPU(segms_first_idx);
   return PointEdgeDistanceForwardCpu(
       points, points_first_idx, segms, segms_first_idx, max_points);
 }
@@ -396,6 +416,10 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(segms);
+  CHECK_CPU(idx_points);
+  CHECK_CPU(grad_dists);
   return PointEdgeDistanceBackwardCpu(points, segms, idx_points, grad_dists);
 }

@@ -464,6 +488,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(points_first_idx);
+  CHECK_CPU(segms);
+  CHECK_CPU(segms_first_idx);
   return EdgePointDistanceForwardCpu(
       points, points_first_idx, segms, segms_first_idx, max_segms);
 }
@@ -514,6 +542,10 @@ std::tuple<torch::Tensor, torch::Tensor> EdgePointDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(segms);
+  CHECK_CPU(idx_segms);
+  CHECK_CPU(grad_dists);
   return EdgePointDistanceBackwardCpu(points, segms, idx_segms, grad_dists);
 }

@@ -567,6 +599,8 @@ torch::Tensor PointFaceArrayDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(tris);
   return PointFaceArrayDistanceForwardCpu(points, tris, min_triangle_area);
 }

@@ -613,6 +647,9 @@ std::tuple<torch::Tensor, torch::Tensor> PointFaceArrayDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(tris);
+  CHECK_CPU(grad_dists);
   return PointFaceArrayDistanceBackwardCpu(
       points, tris, grad_dists, min_triangle_area);
 }
@@ -661,6 +698,8 @@ torch::Tensor PointEdgeArrayDistanceForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(segms);
   return PointEdgeArrayDistanceForwardCpu(points, segms);
 }

@@ -703,5 +742,8 @@ std::tuple<torch::Tensor, torch::Tensor> PointEdgeArrayDistanceBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(segms);
+  CHECK_CPU(grad_dists);
   return PointEdgeArrayDistanceBackwardCpu(points, segms, grad_dists);
 }

@@ -104,6 +104,12 @@ inline void PointsToVolumesForward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points_3d);
+  CHECK_CPU(points_features);
+  CHECK_CPU(volume_densities);
+  CHECK_CPU(volume_features);
+  CHECK_CPU(grid_sizes);
+  CHECK_CPU(mask);
   PointsToVolumesForwardCpu(
       points_3d,
       points_features,
@@ -183,6 +189,14 @@ inline void PointsToVolumesBackward(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points_3d);
+  CHECK_CPU(points_features);
+  CHECK_CPU(grid_sizes);
+  CHECK_CPU(mask);
+  CHECK_CPU(grad_volume_densities);
+  CHECK_CPU(grad_volume_features);
+  CHECK_CPU(grad_points_3d);
+  CHECK_CPU(grad_points_features);
   PointsToVolumesBackwardCpu(
       points_3d,
       points_features,

@@ -15,8 +15,8 @@
 #endif

 #if defined(_WIN64) || defined(_WIN32)
-#define uint unsigned int
-#define ushort unsigned short
+using uint = unsigned int;
+using ushort = unsigned short;
 #endif

 #include "./logging.h" // <- include before torch/extension.h

@@ -417,7 +417,7 @@ __device__ static float atomicMin(float* address, float val) {
       (OUT_PTR),          \
       (NUM_SELECTED_PTR), \
       (NUM_ITEMS),        \
-      stream = (STREAM));
+      (STREAM));

 #define COPY_HOST_DEV(PTR_D, PTR_H, TYPE, SIZE) \
   HANDLECUDA(cudaMemcpy(                        \

@@ -357,11 +357,11 @@ void MAX_WS(
 //
 //
 #define END_PARALLEL() \
-  end_parallel :;      \
+  end_parallel:;       \
   }
 #define END_PARALLEL_NORET() }
 #define END_PARALLEL_2D() \
-  end_parallel :;         \
+  end_parallel:;          \
   }                       \
   }
 #define END_PARALLEL_2D_NORET() \

@@ -70,11 +70,6 @@ struct CamGradInfo {
   float3 pixel_dir_y;
 };

-// TODO: remove once https://github.com/NVlabs/cub/issues/172 is resolved.
-struct IntWrapper {
-  int val;
-};
-
 } // namespace pulsar

 #endif

@@ -149,11 +149,6 @@ IHD CamGradInfo operator*(const CamGradInfo& a, const float& b) {
   return res;
 }

-IHD IntWrapper operator+(const IntWrapper& a, const IntWrapper& b) {
-  IntWrapper res;
-  res.val = a.val + b.val;
-  return res;
-}
 } // namespace pulsar

 #endif

@@ -155,8 +155,8 @@ void backward(
       stream);
   CHECKLAUNCH();
   SUM_WS(
-      (IntWrapper*)(self->ids_sorted_d),
-      (IntWrapper*)(self->n_grad_contributions_d),
+      self->ids_sorted_d,
+      self->n_grad_contributions_d,
       static_cast<int>(num_balls),
       self->workspace_d,
       self->workspace_size,

@@ -52,7 +52,7 @@ HOST void construct(
   self->cam.film_width = width;
   self->cam.film_height = height;
   self->max_num_balls = max_num_balls;
-  MALLOC(self->result_d, float, width* height* n_channels);
+  MALLOC(self->result_d, float, width * height * n_channels);
   self->cam.orthogonal_projection = orthogonal_projection;
   self->cam.right_handed = right_handed_system;
   self->cam.background_normalization_depth = background_normalization_depth;
@@ -93,7 +93,7 @@ HOST void construct(
   MALLOC(self->di_sorted_d, DrawInfo, max_num_balls);
   MALLOC(self->region_flags_d, char, max_num_balls);
   MALLOC(self->num_selected_d, size_t, 1);
-  MALLOC(self->forw_info_d, float, width* height * (3 + 2 * n_track));
+  MALLOC(self->forw_info_d, float, width * height * (3 + 2 * n_track));
   MALLOC(self->min_max_pixels_d, IntersectInfo, 1);
   MALLOC(self->grad_pos_d, float3, max_num_balls);
   MALLOC(self->grad_col_d, float, max_num_balls* n_channels);

@@ -255,7 +255,7 @@ GLOBAL void calc_signature(
  * for every iteration through the loading loop every thread could add a
  * 'hit' to the buffer.
  */
-#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE* RENDER_BLOCK_SIZE * 2
+#define RENDER_BUFFER_SIZE RENDER_BLOCK_SIZE * RENDER_BLOCK_SIZE * 2
 /**
  * The threshold after which the spheres that are in the render buffer
  * are rendered and the buffer is flushed.

@@ -6,9 +6,6 @@
  * LICENSE file in the root directory of this source tree.
  */

-#include "./global.h"
-#include "./logging.h"
-
 /**
  * A compilation unit to provide warnings about the code and avoid
  * repeated messages.

@@ -138,6 +138,9 @@ RasterizeMeshesNaive(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(mesh_to_face_first_idx);
+    CHECK_CPU(num_faces_per_mesh);
     return RasterizeMeshesNaiveCpu(
         face_verts,
         mesh_to_face_first_idx,
@@ -232,6 +235,11 @@ torch::Tensor RasterizeMeshesBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(pix_to_face);
+    CHECK_CPU(grad_zbuf);
+    CHECK_CPU(grad_bary);
+    CHECK_CPU(grad_dists);
     return RasterizeMeshesBackwardCpu(
         face_verts,
         pix_to_face,
@@ -306,6 +314,9 @@ torch::Tensor RasterizeMeshesCoarse(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(mesh_to_face_first_idx);
+    CHECK_CPU(num_faces_per_mesh);
     return RasterizeMeshesCoarseCpu(
         face_verts,
         mesh_to_face_first_idx,
@@ -423,6 +434,8 @@ RasterizeMeshesFine(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(face_verts);
+    CHECK_CPU(bin_faces);
     AT_ERROR("NOT IMPLEMENTED");
   }
 }

@@ -91,6 +91,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsNaive(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(cloud_to_packed_first_idx);
+    CHECK_CPU(num_points_per_cloud);
+    CHECK_CPU(radius);
     return RasterizePointsNaiveCpu(
         points,
         cloud_to_packed_first_idx,
@@ -166,6 +170,10 @@ torch::Tensor RasterizePointsCoarse(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(cloud_to_packed_first_idx);
+    CHECK_CPU(num_points_per_cloud);
+    CHECK_CPU(radius);
     return RasterizePointsCoarseCpu(
         points,
         cloud_to_packed_first_idx,
@@ -232,6 +240,8 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> RasterizePointsFine(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(bin_points);
     AT_ERROR("NOT IMPLEMENTED");
   }
 }
@@ -284,6 +294,10 @@ torch::Tensor RasterizePointsBackward(
     AT_ERROR("Not compiled with GPU support");
 #endif
   } else {
+    CHECK_CPU(points);
+    CHECK_CPU(idxs);
+    CHECK_CPU(grad_zbuf);
+    CHECK_CPU(grad_dists);
     return RasterizePointsBackwardCpu(points, idxs, grad_zbuf, grad_dists);
   }
 }

@@ -107,7 +107,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points, // (N, P, 3)
     const at::Tensor& lengths, // (N,)
     const at::Tensor& K, // (N,)
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   // Check inputs are on the same device
   at::TensorArg p_t{points, "points", 1}, lengths_t{lengths, "lengths", 2},
       k_t{K, "K", 3}, start_idxs_t{start_idxs, "start_idxs", 4};
@@ -129,7 +130,12 @@ at::Tensor FarthestPointSamplingCuda(

   const int64_t N = points.size(0);
   const int64_t P = points.size(1);
-  const int64_t max_K = at::max(K).item<int64_t>();
+  int64_t max_K;
+  if (max_K_known > 0) {
+    max_K = max_K_known;
+  } else {
+    max_K = at::max(K).item<int64_t>();
+  }

   // Initialize the output tensor with the sampled indices
   auto idxs = at::full({N, max_K}, -1, lengths.options());

@@ -43,7 +43,8 @@ at::Tensor FarthestPointSamplingCuda(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs);
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1);

 at::Tensor FarthestPointSamplingCpu(
     const at::Tensor& points,
@@ -56,17 +57,23 @@ at::Tensor FarthestPointSampling(
     const at::Tensor& points,
     const at::Tensor& lengths,
     const at::Tensor& K,
-    const at::Tensor& start_idxs) {
+    const at::Tensor& start_idxs,
+    const int64_t max_K_known = -1) {
   if (points.is_cuda() || lengths.is_cuda() || K.is_cuda()) {
 #ifdef WITH_CUDA
     CHECK_CUDA(points);
     CHECK_CUDA(lengths);
     CHECK_CUDA(K);
     CHECK_CUDA(start_idxs);
-    return FarthestPointSamplingCuda(points, lengths, K, start_idxs);
+    return FarthestPointSamplingCuda(
+        points, lengths, K, start_idxs, max_K_known);
 #else
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(points);
+  CHECK_CPU(lengths);
+  CHECK_CPU(K);
+  CHECK_CPU(start_idxs);
   return FarthestPointSamplingCpu(points, lengths, K, start_idxs);
 }

@@ -71,6 +71,8 @@ inline void SamplePdf(
     AT_ERROR("Not compiled with GPU support.");
 #endif
   }
+  CHECK_CPU(weights);
+  CHECK_CPU(outputs);
   CHECK_CONTIGUOUS(outputs);
   SamplePdfCpu(bins, weights, outputs, eps);
 }

@@ -99,8 +99,7 @@ namespace {
 // and increment it via template recursion until it is equal to the run-time
 // argument N.
 template <
-    template <typename, int64_t>
-    class Kernel,
+    template <typename, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -124,8 +123,7 @@ struct DispatchKernelHelper1D {
 // 1D dispatch: Specialization when curN == maxN
 // We need this base case to avoid infinite template recursion.
 template <
-    template <typename, int64_t>
-    class Kernel,
+    template <typename, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -145,8 +143,7 @@ struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
 // the run-time values of N and M, at which point we dispatch to the run
 // method of the kernel.
 template <
-    template <typename, int64_t, int64_t>
-    class Kernel,
+    template <typename, int64_t, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -203,8 +200,7 @@ struct DispatchKernelHelper2D {

 // 2D dispatch, specialization for curN == maxN
 template <
-    template <typename, int64_t, int64_t>
-    class Kernel,
+    template <typename, int64_t, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -243,8 +239,7 @@ struct DispatchKernelHelper2D<

 // 2D dispatch, specialization for curM == maxM
 template <
-    template <typename, int64_t, int64_t>
-    class Kernel,
+    template <typename, int64_t, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -283,8 +278,7 @@ struct DispatchKernelHelper2D<

 // 2D dispatch, specialization for curN == maxN, curM == maxM
 template <
-    template <typename, int64_t, int64_t>
-    class Kernel,
+    template <typename, int64_t, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -313,8 +307,7 @@ struct DispatchKernelHelper2D<

 // This is the function we expect users to call to dispatch to 1D functions
 template <
-    template <typename, int64_t>
-    class Kernel,
+    template <typename, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,
@@ -330,8 +323,7 @@ void DispatchKernel1D(const int64_t N, Args... args) {

 // This is the function we expect users to call to dispatch to 2D functions
 template <
-    template <typename, int64_t, int64_t>
-    class Kernel,
+    template <typename, int64_t, int64_t> class Kernel,
     typename T,
     int64_t minN,
     int64_t maxN,

@@ -15,3 +15,7 @@
 #define CHECK_CONTIGUOUS_CUDA(x) \
   CHECK_CUDA(x);                 \
   CHECK_CONTIGUOUS(x)
+#define CHECK_CPU(x)                      \
+  TORCH_CHECK(                            \
+      x.device().type() == torch::kCPU,   \
+      "Cannot use CPU implementation: " #x " not on CPU.")

@@ -755,7 +755,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
         if pick_sequences:
             old_len = len(eval_batches)
             eval_batches = [b for b in eval_batches if b[0][0] in pick_sequences]
-            logger.warn(
+            logger.warning(
                 f"Picked eval batches by sequence/cat: {old_len} -> {len(eval_batches)}"
             )
@@ -763,7 +763,7 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):
             old_len = len(eval_batches)
             exclude_sequences = set(self.exclude_sequences)
             eval_batches = [b for b in eval_batches if b[0][0] not in exclude_sequences]
-            logger.warn(
+            logger.warning(
                 f"Excluded eval batches by sequence: {old_len} -> {len(eval_batches)}"
             )

@@ -21,8 +21,6 @@ import logging
 import warnings
 from collections.abc import Mapping
 from dataclasses import dataclass, field
-
-from distutils.version import LooseVersion
 from typing import Any, Callable, ClassVar, Dict, Iterator, List, Optional, Tuple, Type

 import torch
@@ -222,7 +220,8 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
                 + "| 'bicubic' | 'linear' | 'area' | 'nearest-exact'"
             )

-        interpolate_has_antialias = LooseVersion(torch.__version__) >= "1.11"
+        # We assume PyTorch 1.11 and newer.
+        interpolate_has_antialias = True

         if antialias and not interpolate_has_antialias:
             warnings.warn("Antialiased interpolation requires PyTorch 1.11+; ignoring")

@@ -304,11 +304,11 @@ def _show_predictions(
     assert isinstance(preds, list)

     pred_all = []
-    # Randomly choose a subset of the rendered images, sort by ordr in the sequence
+    # Randomly choose a subset of the rendered images, sort by order in the sequence
     n_samples = min(n_samples, len(preds))
     pred_idx = sorted(random.sample(list(range(len(preds))), n_samples))
     for predi in pred_idx:
-        # Make the concatentation for the same camera vertically
+        # Make the concatenation for the same camera vertically
         pred_all.append(
             torch.cat(
                 [
@@ -359,7 +359,7 @@ def _generate_prediction_videos(
     vws = {}
     for k in predicted_keys:
         if k not in preds[0]:
-            logger.warn(f"Cannot generate video for prediction key '{k}'")
+            logger.warning(f"Cannot generate video for prediction key '{k}'")
             continue
         cache_dir = (
             None

@@ -23,11 +23,13 @@ class _ball_query(Function):
     """

     @staticmethod
-    def forward(ctx, p1, p2, lengths1, lengths2, K, radius):
+    def forward(ctx, p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube):
         """
         Argument definitions are the same as in the ball_query function
         """
-        idx, dists = _C.ball_query(p1, p2, lengths1, lengths2, K, radius)
+        idx, dists = _C.ball_query(
+            p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
+        )
         ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
         ctx.mark_non_differentiable(idx)
         return dists, idx
@@ -49,7 +51,7 @@ class _ball_query(Function):
         grad_p1, grad_p2 = _C.knn_points_backward(
             p1, p2, lengths1, lengths2, idx, 2, grad_dists
         )
-        return grad_p1, grad_p2, None, None, None, None
+        return grad_p1, grad_p2, None, None, None, None, None


 def ball_query(
@@ -60,6 +62,7 @@ def ball_query(
     K: int = 500,
     radius: float = 0.2,
     return_nn: bool = True,
+    skip_points_outside_cube: bool = False,
 ):
     """
     Ball Query is an alternative to KNN. It can be
@@ -98,6 +101,9 @@ def ball_query(
             within the radius
         radius: the radius around each point within which the neighbors need to be located
         return_nn: If set to True returns the K neighbor points in p2 for each point in p1.
+        skip_points_outside_cube: If set to True, reduce multiplications of float values
+            by not explicitly calculating distances to points that fall outside the
+            D-cube with side length (2*radius) centered at each point in p1.

     Returns:
         dists: Tensor of shape (N, P1, K) giving the squared distances to
@@ -134,7 +140,9 @@ def ball_query(
     if lengths2 is None:
         lengths2 = torch.full((N,), P2, dtype=torch.int64, device=p1.device)

-    dists, idx = _ball_query.apply(p1, p2, lengths1, lengths2, K, radius)
+    dists, idx = _ball_query.apply(
+        p1, p2, lengths1, lengths2, K, radius, skip_points_outside_cube
+    )

     # Gather the neighbors if needed
     points_nn = masked_gather(p2, idx) if return_nn else None

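Note: the new flag is plumbed through to the public `pytorch3d.ops.ball_query`. A usage sketch, assuming (as in current pytorch3d) that the op returns a KNN-style namedtuple with `dists`/`idx` fields:

```python
import torch
from pytorch3d.ops import ball_query

p1 = torch.rand(2, 128, 3)   # (N, P1, D) query points
p2 = torch.rand(2, 4096, 3)  # (N, P2, D) search points

# The flag is purely an optimization: results match the default path, but
# full distances are not computed for points outside the 2*radius cube.
res = ball_query(p1, p2, K=50, radius=0.1, skip_points_outside_cube=True)
dists, idx = res.dists, res.idx  # both of shape (N, P1, K)
```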
@@ -47,8 +47,7 @@ def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
     # i.e. A[i, j] = 1 if (i,j) is an edge, or
     # A[e0, e1] = 1 & A[e1, e0] = 1
     ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
-    # pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
-    A = torch.sparse.FloatTensor(idx, ones, (V, V))
+    A = torch.sparse_coo_tensor(idx, ones, (V, V), dtype=torch.float32)

     # the sum of i-th row of A gives the degree of the i-th vertex
     deg = torch.sparse.sum(A, dim=1).to_dense()
@@ -62,15 +61,13 @@ def laplacian(verts: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
     # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
     deg1 = torch.where(deg1 > 0.0, 1.0 / deg1, deg1)
     val = torch.cat([deg0, deg1])
-    # pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
-    L = torch.sparse.FloatTensor(idx, val, (V, V))
+    L = torch.sparse_coo_tensor(idx, val, (V, V), dtype=torch.float32)

     # Then we add the diagonal values L[i, i] = -1.
     idx = torch.arange(V, device=verts.device)
     idx = torch.stack([idx, idx], dim=0)
     ones = torch.ones(idx.shape[1], dtype=torch.float32, device=verts.device)
-    # pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
-    L -= torch.sparse.FloatTensor(idx, ones, (V, V))
+    L -= torch.sparse_coo_tensor(idx, ones, (V, V), dtype=torch.float32)

     return L
@@ -126,8 +123,7 @@ def cot_laplacian(
     ii = faces[:, [1, 2, 0]]
     jj = faces[:, [2, 0, 1]]
     idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
-    # pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
-    L = torch.sparse.FloatTensor(idx, cot.view(-1), (V, V))
+    L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V), dtype=torch.float32)

     # Make it symmetric; this means we are also setting
     # L[v2, v1] = cota
@@ -167,7 +163,7 @@ def norm_laplacian(
     v0, v1 = edge_verts[:, 0], edge_verts[:, 1]

     # Side lengths of each edge, of shape (E,)
-    w01 = 1.0 / ((v0 - v1).norm(dim=1) + eps)
+    w01 = torch.reciprocal((v0 - v1).norm(dim=1) + eps)

     # Construct a sparse matrix by basically doing:
     # L[v0, v1] = w01
@@ -175,8 +171,7 @@ def norm_laplacian(
     e01 = edges.t()  # (2, E)

     V = verts.shape[0]
-    # pyre-fixme[16]: Module `sparse` has no attribute `FloatTensor`.
-    L = torch.sparse.FloatTensor(e01, w01, (V, V))
+    L = torch.sparse_coo_tensor(e01, w01, (V, V), dtype=torch.float32)
     L = L + L.t()

     return L

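Note: these hunks mechanically replace the deprecated `torch.sparse.FloatTensor` constructor (and its pyre-fixme suppressions) with the supported `torch.sparse_coo_tensor`, which builds the same COO matrix; `torch.reciprocal` likewise matches `1.0 / x` elementwise. A quick equivalence check:

```python
import torch

idx = torch.tensor([[0, 1], [1, 0]])  # COO indices, shape (2, nnz)
val = torch.ones(idx.shape[1])
A = torch.sparse_coo_tensor(idx, val, (2, 2), dtype=torch.float32)
print(A.to_dense())  # tensor([[0., 1.], [1., 0.]])
```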
@@ -55,6 +55,7 @@ def sample_farthest_points(
     N, P, D = points.shape
     device = points.device

+    constant_length = lengths is None
     # Validate inputs
     if lengths is None:
         lengths = torch.full((N,), P, dtype=torch.int64, device=device)
@@ -65,7 +66,9 @@ def sample_farthest_points(
         raise ValueError("A value in lengths was too large.")

     # TODO: support providing K as a ratio of the total number of points instead of as an int
+    max_K = -1
     if isinstance(K, int):
+        max_K = K
         K = torch.full((N,), K, dtype=torch.int64, device=device)
     elif isinstance(K, list):
         K = torch.tensor(K, dtype=torch.int64, device=device)
@@ -82,15 +85,19 @@ def sample_farthest_points(
     K = K.to(torch.int64)

     # Generate the starting indices for sampling
-    start_idxs = torch.zeros_like(lengths)
     if random_start_point:
-        for n in range(N):
-            # pyre-fixme[6]: For 1st param expected `int` but got `Tensor`.
-            start_idxs[n] = torch.randint(high=lengths[n], size=(1,)).item()
+        if constant_length:
+            start_idxs = torch.randint(high=P, size=(N,), device=device)
+        else:
+            start_idxs = (lengths * torch.rand(lengths.size(), device=device)).to(
+                torch.int64
+            )
+    else:
+        start_idxs = torch.zeros_like(lengths)

     with torch.no_grad():
         # pyre-fixme[16]: `pytorch3d_._C` has no attribute `sample_farthest_points`.
-        idx = _C.sample_farthest_points(points, lengths, K, start_idxs)
+        idx = _C.sample_farthest_points(points, lengths, K, start_idxs, max_K)
     sampled_points = masked_gather(points, idx)

     return sampled_points, idx

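Note: from Python the public API is unchanged; passing K as a plain int now lets the CUDA path skip the `at::max(K).item()` device read, and random start points are drawn in a vectorized way instead of a per-cloud loop. A usage sketch (signature as in the function above; runs on CPU too):

```python
import torch
from pytorch3d.ops import sample_farthest_points

points = torch.rand(4, 1000, 3)                 # (N, P, D)
lengths = torch.tensor([1000, 800, 900, 1000])  # valid points per cloud

# K as an int hits the fast path (max_K known without reading K back);
# a list or tensor of per-cloud K values still works as before.
sampled, idx = sample_farthest_points(
    points, lengths=lengths, K=50, random_start_point=True
)
print(sampled.shape, idx.shape)  # torch.Size([4, 50, 3]) torch.Size([4, 50])
```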
@@ -160,9 +160,10 @@ def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:

     # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
     # forall i; we pick the best-conditioned one (with the largest denominator)
-    out = quat_candidates[
-        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
-    ].reshape(batch_dim + (4,))
+    indices = q_abs.argmax(dim=-1, keepdim=True)
+    expand_dims = list(batch_dim) + [1, 4]
+    gather_indices = indices.unsqueeze(-1).expand(expand_dims)
+    out = torch.gather(quat_candidates, -2, gather_indices).squeeze(-2)
     return standardize_quaternion(out)


@@ -293,10 +294,11 @@ def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
     tait_bryan = i0 != i2
     if tait_bryan:
         central_angle = torch.asin(
-            matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
+            torch.clamp(matrix[..., i0, i2], -1.0, 1.0)
+            * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
         )
     else:
-        central_angle = torch.acos(matrix[..., i0, i0])
+        central_angle = torch.acos(torch.clamp(matrix[..., i0, i0], -1.0, 1.0))

     o = (
         _angle_from_tan(

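Note: the clamps above guard `asin`/`acos` against arguments that drift marginally outside [-1, 1] through float round-off in the rotation matrix, which would otherwise produce NaN:

```python
import torch

x = torch.tensor(1.0000001)  # slightly out of range after round-off
print(torch.acos(x))                          # tensor(nan)
print(torch.acos(torch.clamp(x, -1.0, 1.0)))  # tensor(0.)
```

The gather-based selection in `matrix_to_quaternion` picks the same best-conditioned candidate (the argmax of `q_abs`) as the old boolean `one_hot` mask, just via integer indexing instead of a mask.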
setup.py
@@ -75,6 +75,21 @@ def get_extensions():
         ]
         if os.name != "nt":
             nvcc_args.append("-std=c++17")
+
+        # CUDA 13.0+ compatibility flags for pulsar.
+        # Starting with CUDA 13, __global__ function visibility changed.
+        # See: https://developer.nvidia.com/blog/
+        # cuda-c-compiler-updates-impacting-elf-visibility-and-linkage/
+        cuda_version = torch.version.cuda
+        if cuda_version is not None:
+            major = int(cuda_version.split(".")[0])
+            if major >= 13:
+                nvcc_args.extend(
+                    [
+                        "--device-entity-has-hidden-visibility=false",
+                        "-static-global-template-stub=false",
+                    ]
+                )
         if cub_home is None:
             prefix = os.environ.get("CONDA_PREFIX", None)
             if prefix is not None and os.path.isdir(prefix + "/include/cub"):
@@ -134,7 +149,7 @@ if os.getenv("PYTORCH3D_NO_NINJA", "0") == "1":

     class BuildExtension(torch.utils.cpp_extension.BuildExtension):
         def __init__(self, *args, **kwargs):
-            super().__init__(use_ninja=False, *args, **kwargs)
+            super().__init__(*args, use_ninja=False, **kwargs)

 else:
     BuildExtension = torch.utils.cpp_extension.BuildExtension

tests/benchmarks/bm_ball_query_large.py (new file)
@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from itertools import product
+
+import torch
+from fvcore.common.benchmark import benchmark
+
+from pytorch3d.ops.ball_query import ball_query
+
+
+def ball_query_square(
+    N: int, P1: int, P2: int, D: int, K: int, radius: float, device: str
+):
+    device = torch.device(device)
+    pts1 = torch.rand(N, P1, D, device=device)
+    pts2 = torch.rand(N, P2, D, device=device)
+    torch.cuda.synchronize()
+
+    def output():
+        ball_query(pts1, pts2, K=K, radius=radius, skip_points_outside_cube=True)
+        torch.cuda.synchronize()
+
+    return output
+
+
+def bm_ball_query() -> None:
+    backends = ["cpu", "cuda:0"]
+
+    kwargs_list = []
+    Ns = [32]
+    P1s = [256]
+    P2s = [2**p for p in range(9, 20, 2)]
+    Ds = [3, 10]
+    Ks = [500]
+    Rs = [0.01, 0.1]
+    test_cases = product(Ns, P1s, P2s, Ds, Ks, Rs, backends)
+    for case in test_cases:
+        N, P1, P2, D, K, R, b = case
+        kwargs_list.append(
+            {"N": N, "P1": P1, "P2": P2, "D": D, "K": K, "radius": R, "device": b}
+        )
+    benchmark(
+        ball_query_square,
+        "BALLQUERY_SQUARE",
+        kwargs_list,
+        num_iters=30,
+        warmup_iters=1,
+    )
+
+
+if __name__ == "__main__":
+    bm_ball_query()

@@ -31,6 +31,13 @@ def skip_opengl_requested() -> bool:
 usesOpengl = unittest.skipIf(skip_opengl_requested(), "uses opengl")


+def have_multiple_gpus() -> bool:
+    return torch.cuda.device_count() > 1
+
+
+needs_multigpu = unittest.skipIf(not have_multiple_gpus(), "needs multiple GPUs")
+
+
 def get_tests_dir() -> Path:
     """
     Returns Path for the directory containing this file.

@@ -15,7 +15,7 @@ from tests.common_testing import get_pytorch3d_dir

 # This file groups together tests which look at the code without running it.
 class TestBuild(unittest.TestCase):
-    def test_no_import_cycles(self):
+    def _test_no_import_cycles(self):
         # Check each module of pytorch3d imports cleanly,
         # which may fail if there are import cycles.

@@ -78,7 +78,7 @@ class TestBuild(unittest.TestCase):

         self.assertListEqual(sorted(listed_in_json), notes_on_disk)

-    def test_no_import_cycles(self):
+    def _test_no_import_cycles(self):
         # Check each module of pytorch3d imports cleanly,
         # which may fail if there are import cycles.

@@ -72,6 +72,7 @@ class TestKNN(TestCaseMixin, unittest.TestCase):
         factors = [Ns, Ds, P1s, P2s, Ks, norms]
         for N, D, P1, P2, K, norm in product(*factors):
             for version in versions:
+                torch.manual_seed(2)
                 if version == 3 and K > 4:
                     continue
                 x = torch.randn(N, P1, D, device=device, requires_grad=True)

@@ -703,80 +703,6 @@ class TestPointclouds(TestCaseMixin, unittest.TestCase):
         self.assertEqual(cuda_device, cloud.device)
         self.assertIsNot(cloud, converted_cloud)

-    def test_to_list(self):
-        cloud = self.init_cloud(5, 100, 10)
-        device = torch.device("cuda:1")
-
-        new_cloud = cloud.to(device)
-        self.assertTrue(new_cloud.device == device)
-        self.assertTrue(cloud.device == torch.device("cuda:0"))
-        for attrib in [
-            "points_padded",
-            "points_packed",
-            "normals_padded",
-            "normals_packed",
-            "features_padded",
-            "features_packed",
-            "num_points_per_cloud",
-            "cloud_to_packed_first_idx",
-            "padded_to_packed_idx",
-        ]:
-            self.assertClose(
-                getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
-            )
-        for i in range(len(cloud)):
-            self.assertClose(
-                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
-            )
-            self.assertClose(
-                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
-            )
-            self.assertClose(
-                cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
-            )
-        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
-        self.assertTrue(cloud.equisized == new_cloud.equisized)
-        self.assertTrue(cloud._N == new_cloud._N)
-        self.assertTrue(cloud._P == new_cloud._P)
-        self.assertTrue(cloud._C == new_cloud._C)
-
-    def test_to_tensor(self):
-        cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
-        device = torch.device("cuda:1")
-
-        new_cloud = cloud.to(device)
-        self.assertTrue(new_cloud.device == device)
-        self.assertTrue(cloud.device == torch.device("cuda:0"))
-        for attrib in [
-            "points_padded",
-            "points_packed",
-            "normals_padded",
-            "normals_packed",
-            "features_padded",
-            "features_packed",
-            "num_points_per_cloud",
-            "cloud_to_packed_first_idx",
-            "padded_to_packed_idx",
-        ]:
-            self.assertClose(
-                getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
-            )
-        for i in range(len(cloud)):
-            self.assertClose(
-                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
-            )
-            self.assertClose(
-                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
-            )
-            self.assertClose(
-                cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
-            )
-        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
-        self.assertTrue(cloud.equisized == new_cloud.equisized)
-        self.assertTrue(cloud._N == new_cloud._N)
-        self.assertTrue(cloud._P == new_cloud._P)
-        self.assertTrue(cloud._C == new_cloud._C)
-
     def test_split(self):
         clouds = self.init_cloud(5, 100, 10)
         split_sizes = [2, 3]

tests/test_pointclouds_multigpu.py (new file)
@@ -0,0 +1,166 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import numpy as np
+import torch
+from pytorch3d.structures.pointclouds import Pointclouds
+
+from .common_testing import needs_multigpu, TestCaseMixin
+
+
+class TestPointclouds(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        np.random.seed(42)
+        torch.manual_seed(42)
+
+    @staticmethod
+    def init_cloud(
+        num_clouds: int = 3,
+        max_points: int = 100,
+        channels: int = 4,
+        lists_to_tensors: bool = False,
+        with_normals: bool = True,
+        with_features: bool = True,
+        min_points: int = 0,
+        requires_grad: bool = False,
+    ):
+        """
+        Function to generate a Pointclouds object of N meshes with
+        random number of points.
+
+        Args:
+            num_clouds: Number of clouds to generate.
+            channels: Number of features.
+            max_points: Max number of points per cloud.
+            lists_to_tensors: Determines whether the generated clouds should be
+                constructed from lists (=False) or
+                tensors (=True) of points/normals/features.
+            with_normals: bool whether to include normals
+            with_features: bool whether to include features
+            min_points: Min number of points per cloud
+
+        Returns:
+            Pointclouds object.
+        """
+        device = torch.device("cuda:0")
+        p = torch.randint(low=min_points, high=max_points, size=(num_clouds,))
+        if lists_to_tensors:
+            p.fill_(p[0])
+
+        points_list = [
+            torch.rand(
+                (i, 3), device=device, dtype=torch.float32, requires_grad=requires_grad
+            )
+            for i in p
+        ]
+        normals_list, features_list = None, None
+        if with_normals:
+            normals_list = [
+                torch.rand(
+                    (i, 3),
+                    device=device,
+                    dtype=torch.float32,
+                    requires_grad=requires_grad,
+                )
+                for i in p
+            ]
+        if with_features:
+            features_list = [
+                torch.rand(
+                    (i, channels),
+                    device=device,
+                    dtype=torch.float32,
+                    requires_grad=requires_grad,
+                )
+                for i in p
+            ]
+
+        if lists_to_tensors:
+            points_list = torch.stack(points_list)
+            if with_normals:
+                normals_list = torch.stack(normals_list)
+            if with_features:
+                features_list = torch.stack(features_list)
+
+        return Pointclouds(points_list, normals=normals_list, features=features_list)
+
+    @needs_multigpu
+    def test_to_list(self):
+        cloud = self.init_cloud(5, 100, 10)
+        device = torch.device("cuda:1")
+
+        new_cloud = cloud.to(device)
+        self.assertTrue(new_cloud.device == device)
+        self.assertTrue(cloud.device == torch.device("cuda:0"))
+        for attrib in [
+            "points_padded",
+            "points_packed",
+            "normals_padded",
+            "normals_packed",
+            "features_padded",
+            "features_packed",
+            "num_points_per_cloud",
+            "cloud_to_packed_first_idx",
+            "padded_to_packed_idx",
+        ]:
+            self.assertClose(
+                getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
+            )
+        for i in range(len(cloud)):
+            self.assertClose(
+                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
+            )
+            self.assertClose(
+                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
+            )
+            self.assertClose(
+                cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
+            )
+        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
+        self.assertTrue(cloud.equisized == new_cloud.equisized)
+        self.assertTrue(cloud._N == new_cloud._N)
+        self.assertTrue(cloud._P == new_cloud._P)
+        self.assertTrue(cloud._C == new_cloud._C)
+
+    @needs_multigpu
+    def test_to_tensor(self):
+        cloud = self.init_cloud(5, 100, 10, lists_to_tensors=True)
+        device = torch.device("cuda:1")
+
+        new_cloud = cloud.to(device)
+        self.assertTrue(new_cloud.device == device)
+        self.assertTrue(cloud.device == torch.device("cuda:0"))
+        for attrib in [
+            "points_padded",
+            "points_packed",
+            "normals_padded",
+            "normals_packed",
+            "features_padded",
+            "features_packed",
+            "num_points_per_cloud",
+            "cloud_to_packed_first_idx",
+            "padded_to_packed_idx",
+        ]:
+            self.assertClose(
+                getattr(new_cloud, attrib)().cpu(), getattr(cloud, attrib)().cpu()
+            )
+        for i in range(len(cloud)):
+            self.assertClose(
+                cloud.points_list()[i].cpu(), new_cloud.points_list()[i].cpu()
+            )
+            self.assertClose(
+                cloud.normals_list()[i].cpu(), new_cloud.normals_list()[i].cpu()
+            )
+            self.assertClose(
+                cloud.features_list()[i].cpu(), new_cloud.features_list()[i].cpu()
+            )
+        self.assertTrue(all(cloud.valid.cpu() == new_cloud.valid.cpu()))
+        self.assertTrue(cloud.equisized == new_cloud.equisized)
+        self.assertTrue(cloud._N == new_cloud._N)
+        self.assertTrue(cloud._P == new_cloud._P)
+        self.assertTrue(cloud._C == new_cloud._C)

@@ -165,7 +165,7 @@ class TestICP(TestCaseMixin, unittest.TestCase):
         a set of randomly-sized Pointclouds and on their padded versions.
         """

-        torch.manual_seed(4)
+        torch.manual_seed(14)
         device = torch.device("cuda:0")

         for estimate_scale in (True, False):

@@ -29,7 +29,7 @@ from pytorch3d.renderer.opengl import MeshRasterizerOpenGL
 from pytorch3d.structures import Meshes, Pointclouds
 from pytorch3d.utils.ico_sphere import ico_sphere

-from .common_testing import TestCaseMixin, usesOpengl
+from .common_testing import needs_multigpu, TestCaseMixin, usesOpengl


 # Set the number of GPUS you want to test with
@@ -116,6 +116,7 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
         output_images = renderer(mesh)
         self.assertEqual(output_images.device, device2)

+    @needs_multigpu
     def test_mesh_renderer_to(self):
         self._mesh_renderer_to(MeshRasterizer, SoftPhongShader)

@@ -173,6 +174,7 @@ class TestRenderMeshesMultiGPU(TestCaseMixin, unittest.TestCase):
         for _ in range(100):
             model(verts, texs)

+    @needs_multigpu
     def test_render_meshes(self):
         self._render_meshes(MeshRasterizer, HardGouraudShader)

@@ -63,9 +63,6 @@ class TestTensorProperties(TestCaseMixin, unittest.TestCase):
         self.assertEqual(example_gpu.device.type, "cuda")
         self.assertIsNotNone(example_gpu.device.index)

-        example_gpu1 = example.cuda(1)
-        self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
-
     def test_clone(self):
         # Check clone method
         example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))

@@ -8,7 +8,6 @@
 import itertools
 import math
 import unittest
-from distutils.version import LooseVersion
 from typing import Optional, Union

 import numpy as np
@@ -271,7 +270,6 @@ class TestRotationConversion(TestCaseMixin, unittest.TestCase):
             torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
         )

-    @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
     def test_scriptable(self):
         torch.jit.script(axis_angle_to_matrix)
         torch.jit.script(axis_angle_to_quaternion)

@@ -7,7 +7,6 @@

 import math
 import unittest
-from distutils.version import LooseVersion

 import numpy as np
 import torch
@@ -255,7 +254,6 @@ class TestSO3(TestCaseMixin, unittest.TestCase):
         # all grad values have to be finite
         self.assertTrue(torch.isfinite(r.grad).all())

-    @unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
     def test_scriptable(self):
         torch.jit.script(so3_exp_map)
         torch.jit.script(so3_log_map)