marching_cubes type fix

Summary: fixes https://github.com/facebookresearch/pytorch3d/issues/1679

Reviewed By: MichaelRamamonjisoa

Differential Revision: D50949933

fbshipit-source-id: 5c467de8bf84dd2a3d61748b3846678582d24ea3
This commit is contained in:
Jeremy Reizenstein
2023-11-14 07:38:54 -08:00
committed by Facebook GitHub Bot
parent 2f11ddc5ee
commit f613682551
2 changed files with 11 additions and 8 deletions

View File

@@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
compactedVoxelArray,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
voxelOccupied,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
voxelOccupiedScan,
uint numVoxels) {
uint id = blockIdx.x * blockDim.x + threadIdx.x;
@@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
compactedVoxelArray,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
const uint activeVoxels,
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
@@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
// number of active voxels
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
@@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
numVoxels);
AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize();
@@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
// total number of vertices
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
// Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels.
@@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),