mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	marching_cubes type fix
Summary: fixes https://github.com/facebookresearch/pytorch3d/issues/1679 Reviewed By: MichaelRamamonjisoa Differential Revision: D50949933 fbshipit-source-id: 5c467de8bf84dd2a3d61748b3846678582d24ea3
This commit is contained in:
		
							parent
							
								
									2f11ddc5ee
								
							
						
					
					
						commit
						f613682551
					
				@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
 | 
				
			|||||||
        compactedVoxelArray,
 | 
					        compactedVoxelArray,
 | 
				
			||||||
    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
					    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
				
			||||||
        voxelOccupied,
 | 
					        voxelOccupied,
 | 
				
			||||||
    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
					    const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
 | 
				
			||||||
        voxelOccupiedScan,
 | 
					        voxelOccupiedScan,
 | 
				
			||||||
    uint numVoxels) {
 | 
					    uint numVoxels) {
 | 
				
			||||||
  uint id = blockIdx.x * blockDim.x + threadIdx.x;
 | 
					  uint id = blockIdx.x * blockDim.x + threadIdx.x;
 | 
				
			||||||
@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
 | 
				
			|||||||
    at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
 | 
					    at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
 | 
				
			||||||
    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
					    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
				
			||||||
        compactedVoxelArray,
 | 
					        compactedVoxelArray,
 | 
				
			||||||
    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
 | 
					    at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
 | 
				
			||||||
    const uint activeVoxels,
 | 
					    const uint activeVoxels,
 | 
				
			||||||
    const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
 | 
					    const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
 | 
				
			||||||
    const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
 | 
					    const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
 | 
				
			||||||
@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
				
			|||||||
  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
 | 
					  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // number of active voxels
 | 
					  // number of active voxels
 | 
				
			||||||
  int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
 | 
					  int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  const int device_id = vol.device().index();
 | 
					  const int device_id = vol.device().index();
 | 
				
			||||||
  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
 | 
					  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
 | 
				
			||||||
@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
				
			|||||||
  CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
 | 
					  CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
 | 
				
			||||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
					      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
					      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
					      d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      numVoxels);
 | 
					      numVoxels);
 | 
				
			||||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
					  AT_CUDA_CHECK(cudaGetLastError());
 | 
				
			||||||
  cudaDeviceSynchronize();
 | 
					  cudaDeviceSynchronize();
 | 
				
			||||||
@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
				
			|||||||
  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
 | 
					  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // total number of vertices
 | 
					  // total number of vertices
 | 
				
			||||||
  int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
 | 
					  int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  // Execute "GenerateFacesKernel" kernel
 | 
					  // Execute "GenerateFacesKernel" kernel
 | 
				
			||||||
  // This runs only on the occupied voxels.
 | 
					  // This runs only on the occupied voxels.
 | 
				
			||||||
@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
				
			|||||||
      faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
 | 
					      faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
 | 
				
			||||||
      ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
 | 
					      ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
					      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
					      d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
 | 
				
			||||||
      activeVoxels,
 | 
					      activeVoxels,
 | 
				
			||||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
					      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
				
			||||||
      faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
 | 
					      faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
 | 
				
			||||||
 | 
				
			|||||||
@ -939,8 +939,11 @@ class TestMarchingCubes(TestCaseMixin, unittest.TestCase):
 | 
				
			|||||||
        u = u[None].float()
 | 
					        u = u[None].float()
 | 
				
			||||||
        verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
 | 
					        verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
 | 
				
			||||||
        verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
 | 
					        verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
 | 
				
			||||||
        self.assertClose(verts[0], verts2[0])
 | 
					        self.assertClose(verts2[0], verts[0])
 | 
				
			||||||
        self.assertClose(faces[0], faces2[0])
 | 
					        self.assertClose(faces2[0], faces[0])
 | 
				
			||||||
 | 
					        verts3, faces3 = marching_cubes(u.cuda(), 0, return_local_coords=False)
 | 
				
			||||||
 | 
					        self.assertEqual(len(verts3), len(verts))
 | 
				
			||||||
 | 
					        self.assertEqual(len(faces3), len(faces))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
    def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):
 | 
					    def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user