mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
marching_cubes type fix
Summary: fixes https://github.com/facebookresearch/pytorch3d/issues/1679 Reviewed By: MichaelRamamonjisoa Differential Revision: D50949933 fbshipit-source-id: 5c467de8bf84dd2a3d61748b3846678582d24ea3
This commit is contained in:
parent
2f11ddc5ee
commit
f613682551
@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
|
||||
compactedVoxelArray,
|
||||
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
voxelOccupied,
|
||||
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
|
||||
voxelOccupiedScan,
|
||||
uint numVoxels) {
|
||||
uint id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
|
||||
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
|
||||
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
|
||||
compactedVoxelArray,
|
||||
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
|
||||
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
|
||||
const uint activeVoxels,
|
||||
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
|
||||
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
|
||||
@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
|
||||
|
||||
// number of active voxels
|
||||
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
|
||||
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
|
||||
|
||||
const int device_id = vol.device().index();
|
||||
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
|
||||
@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
|
||||
numVoxels);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
cudaDeviceSynchronize();
|
||||
@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
|
||||
|
||||
// total number of vertices
|
||||
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
|
||||
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
|
||||
|
||||
// Execute "GenerateFacesKernel" kernel
|
||||
// This runs only on the occupied voxels.
|
||||
@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
|
||||
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
|
||||
activeVoxels,
|
||||
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
|
||||
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
|
||||
|
@ -939,8 +939,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
|
||||
u = u[None].float()
|
||||
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
|
||||
verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
|
||||
self.assertClose(verts[0], verts2[0])
|
||||
self.assertClose(faces[0], faces2[0])
|
||||
self.assertClose(verts2[0], verts[0])
|
||||
self.assertClose(faces2[0], faces[0])
|
||||
verts3, faces3 = marching_cubes(u.cuda(), 0, return_local_coords=False)
|
||||
self.assertEqual(len(verts3), len(verts))
|
||||
self.assertEqual(len(faces3), len(faces))
|
||||
|
||||
@staticmethod
|
||||
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):
|
||||
|
Loading…
x
Reference in New Issue
Block a user