CUDA marching_cubes fix

Summary:
Fix an inclusive- vs. exclusive-scan mix-up that was accidentally introduced when the Thrust dependency (`thrust::exclusive_scan`) was removed and the scan was reimplemented using `at::cumsum` (which performs an inclusive scan).
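
For context, here is a minimal standalone sketch of the difference between the two scan flavors (plain C++17 `<numeric>`, not the ATen/Thrust code touched by this diff; the voxel counts are made up): the inclusive scan at position i already includes element i, while the exclusive scan is the sum of everything before i, which is what the kernels need as a write offset.

#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical per-voxel vertex counts, as a classify pass might produce.
  std::vector<int> counts = {3, 0, 6, 3};

  std::vector<int> inc(counts.size()), exc(counts.size());
  std::inclusive_scan(counts.begin(), counts.end(), inc.begin());
  std::exclusive_scan(counts.begin(), counts.end(), exc.begin(), 0);

  // inc == {3, 3, 9, 12}: entry i already counts voxel i itself.
  // exc == {0, 3, 3, 9}:  entry i is the sum of everything before i,
  //                       i.e. the correct output offset for voxel i.
  for (size_t i = 0; i < counts.size(); ++i) {
    std::printf("voxel %zu: count=%d inclusive=%d exclusive=%d\n",
                i, counts[i], inc[i], exc[i]);
  }
  return 0;
}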

This fixes two GitHub-reported issues:

 * https://github.com/facebookresearch/pytorch3d/issues/1731
 * https://github.com/facebookresearch/pytorch3d/issues/1751

Reviewed By: bottler

Differential Revision: D54605545

fbshipit-source-id: da9e92f3f8a9a35f7b7191428d0b9a9ca03e0d4d
Author: Jaap Suter, 2024-03-07 15:38:24 -08:00 (committed by Facebook GitHub Bot)
Parent: a27755db41
Commit: 7566530669


@@ -382,6 +382,44 @@ __global__ void GenerateFacesKernel(
} // end for grid-strided kernel
}
+// ATen/Torch does not have an exclusive-scan operator. Additionally, in the
+// code below we need to get the "total number of items to work on" after
+// a scan, which with an inclusive-scan would simply be the value of the last
+// element in the tensor.
+//
+// This utility function hits two birds with one stone, by running
+// an inclusive-scan into a right-shifted view of a tensor that's
+// allocated to be one element bigger than the input tensor.
+//
+// Note: the return tensor is `int64_t` per element, even if the input
+// tensor is only 32-bit. Also, the return tensor is one element bigger
+// than the input one.
+//
+// The second (optional) argument is an output argument that gets the
+// value of the last element of the return tensor (because you almost
+// always need this CPU-side right after this function anyway).
+static at::Tensor ExclusiveScanAndTotal(
+    const at::Tensor& inTensor,
+    int64_t* optTotal = nullptr) {
+  const auto inSize = inTensor.sizes()[0];
+  auto retTensor = at::zeros({inSize + 1}, at::kLong).to(inTensor.device());
+
+  using at::indexing::None;
+  using at::indexing::Slice;
+  auto rightShiftedView = retTensor.index({Slice(1, None)});
+
+  // Do an (inclusive-scan) cumulative sum into the view that's
+  // shifted one element to the right...
+  at::cumsum_out(rightShiftedView, inTensor, 0, at::kLong);
+
+  if (optTotal) {
+    *optTotal = retTensor[inSize].cpu().item<int64_t>();
+  }
+
+  // ...so that the not-shifted tensor holds the exclusive-scan.
+  return retTensor;
+}
// Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
// create triangle meshes from an implicit function (one of the form f(x, y, z)
// = 0). It works by iteratively checking a grid of cubes superimposed over a
@@ -444,20 +482,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
using at::indexing::Slice;
auto d_voxelVerts =
-at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
+at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
.to(vol.device());
-auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
auto d_voxelOccupied =
-at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
+at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
.to(vol.device());
-auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
// Execute "ClassifyVoxelKernel" kernel to precompute
// two arrays - d_voxelOccupied and d_voxelVertices to global memory,
// which stores the occupancy state and number of voxel vertices per voxel.
ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
-d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
+d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
isolevel);
AT_CUDA_CHECK(cudaGetLastError());
@@ -467,12 +503,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
// count for voxels in the grid and compute the number of active voxels.
// If the number of active voxels is 0, return zero tensor for verts and
// faces.
-auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
-auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
// number of active voxels
-int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
+int64_t activeVoxels = 0;
+auto d_voxelOccupiedScan =
+ExclusiveScanAndTotal(d_voxelOccupied, &activeVoxels);
const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -487,24 +520,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
return std::make_tuple(verts, faces, ids);
}
// Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
// Execute "CompactVoxelsKernel" kernel to compress voxels for acceleration.
// This allows us to run triangle generation on only the occupied voxels.
auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-d_voxelOccupiedScan_
+d_voxelOccupiedScan
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
numVoxels);
AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize();
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
-auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
-auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
// total number of vertices
-int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
+int64_t totalVerts = 0;
+auto d_voxelVertsScan = ExclusiveScanAndTotal(d_voxelVerts, &totalVerts);
// Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels.
@@ -524,7 +554,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
-d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+d_voxelVertsScan.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
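
As a usage note, here is a hedged host-side sketch (libtorch C++; the tensor values and local names `counts`, `scan`, `shifted` are illustrative, not part of this diff) of the right-shifted-view trick that `ExclusiveScanAndTotal` relies on: an inclusive `at::cumsum_out` into elements 1..n of a zero-initialized (n+1)-element tensor leaves the exclusive scan in elements 0..n-1 and the grand total in element n.

#include <torch/torch.h>

#include <cstdint>
#include <iostream>

int main() {
  using at::indexing::None;
  using at::indexing::Slice;

  // Illustrative per-voxel vertex counts (a stand-in for d_voxelVerts).
  at::Tensor counts = torch::tensor({3, 0, 6, 3}, torch::kInt);
  const int64_t n = counts.size(0);

  // One extra, zero-initialized element; index 0 stays 0.
  at::Tensor scan = torch::zeros({n + 1}, torch::kLong);
  auto shifted = scan.index({Slice(1, None)});

  // Inclusive cumsum written into scan[1..n]: scan[0..n-1] is now the
  // exclusive scan, and scan[n] is the total.
  at::cumsum_out(shifted, counts, 0, at::kLong);

  std::cout << scan << "\n";                    // 0 3 3 9 12
  std::cout << scan[n].item<int64_t>() << "\n"; // total = 12
  return 0;
}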