mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	CUDA marching_cubes fix
Summary: Fix an inclusive vs exclusive scan mix-up that was accidentally introduced when removing the Thrust dependency (`Thrust::exclusive_scan`) and reimplementing it using `at::cumsum` (which does an inclusive scan). This fixes two Github reported issues: * https://github.com/facebookresearch/pytorch3d/issues/1731 * https://github.com/facebookresearch/pytorch3d/issues/1751 Reviewed By: bottler Differential Revision: D54605545 fbshipit-source-id: da9e92f3f8a9a35f7b7191428d0b9a9ca03e0d4d
This commit is contained in:
		
							parent
							
								
									a27755db41
								
							
						
					
					
						commit
						7566530669
					
				@ -382,6 +382,44 @@ __global__ void GenerateFacesKernel(
 | 
			
		||||
  } // end for grid-strided kernel
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ATen/Torch does not have an exclusive-scan operator. Additionally, in the
 | 
			
		||||
// code below we need to get the "total number of items to work on" after
 | 
			
		||||
// a scan, which with an inclusive-scan would simply be the value of the last
 | 
			
		||||
// element in the tensor.
 | 
			
		||||
//
 | 
			
		||||
// This utility function hits two birds with one stone, by running
 | 
			
		||||
// an inclusive-scan into a right-shifted view of a tensor that's
 | 
			
		||||
// allocated to be one element bigger than the input tensor.
 | 
			
		||||
//
 | 
			
		||||
// Note; return tensor is `int64_t` per element, even if the input
 | 
			
		||||
// tensor is only 32-bit. Also, the return tensor is one element bigger
 | 
			
		||||
// than the input one.
 | 
			
		||||
//
 | 
			
		||||
// Secondary optional argument is an output argument that gets the
 | 
			
		||||
// value of the last element of the return tensor (because you almost
 | 
			
		||||
// always need this CPU-side right after this function anyway).
 | 
			
		||||
static at::Tensor ExclusiveScanAndTotal(
 | 
			
		||||
    const at::Tensor& inTensor,
 | 
			
		||||
    int64_t* optTotal = nullptr) {
 | 
			
		||||
  const auto inSize = inTensor.sizes()[0];
 | 
			
		||||
  auto retTensor = at::zeros({inSize + 1}, at::kLong).to(inTensor.device());
 | 
			
		||||
 | 
			
		||||
  using at::indexing::None;
 | 
			
		||||
  using at::indexing::Slice;
 | 
			
		||||
  auto rightShiftedView = retTensor.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // Do an (inclusive-scan) cumulative sum in to the view that's
 | 
			
		||||
  // shifted one element to the right...
 | 
			
		||||
  at::cumsum_out(rightShiftedView, inTensor, 0, at::kLong);
 | 
			
		||||
 | 
			
		||||
  if (optTotal) {
 | 
			
		||||
    *optTotal = retTensor[inSize].cpu().item<int64_t>();
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // ...so that the not-shifted tensor holds the exclusive-scan
 | 
			
		||||
  return retTensor;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Entrance for marching cubes cuda extension. Marching Cubes is an algorithm to
 | 
			
		||||
// create triangle meshes from an implicit function (one of the form f(x, y, z)
 | 
			
		||||
// = 0). It works by iteratively checking a grid of cubes superimposed over a
 | 
			
		||||
@ -444,20 +482,18 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  using at::indexing::Slice;
 | 
			
		||||
 | 
			
		||||
  auto d_voxelVerts =
 | 
			
		||||
      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
          .to(vol.device());
 | 
			
		||||
  auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
 | 
			
		||||
  auto d_voxelOccupied =
 | 
			
		||||
      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
          .to(vol.device());
 | 
			
		||||
  auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // Execute "ClassifyVoxelKernel" kernel to precompute
 | 
			
		||||
  // two arrays - d_voxelOccupied and d_voxelVertices to global memory,
 | 
			
		||||
  // which stores the occupancy state and number of voxel vertices per voxel.
 | 
			
		||||
  ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
 | 
			
		||||
      d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
			
		||||
      isolevel);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
@ -467,12 +503,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  // count for voxels in the grid and compute the number of active voxels.
 | 
			
		||||
  // If the number of active voxels is 0, return zero tensor for verts and
 | 
			
		||||
  // faces.
 | 
			
		||||
 | 
			
		||||
  auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
 | 
			
		||||
  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // number of active voxels
 | 
			
		||||
  int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
 | 
			
		||||
  int64_t activeVoxels = 0;
 | 
			
		||||
  auto d_voxelOccupiedScan =
 | 
			
		||||
      ExclusiveScanAndTotal(d_voxelOccupied, &activeVoxels);
 | 
			
		||||
 | 
			
		||||
  const int device_id = vol.device().index();
 | 
			
		||||
  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
 | 
			
		||||
@ -487,24 +520,21 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
    return std::make_tuple(verts, faces, ids);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
 | 
			
		||||
  // Execute "CompactVoxelsKernel" kernel to compress voxels for acceleration.
 | 
			
		||||
  // This allows us to run triangle generation on only the occupied voxels.
 | 
			
		||||
  auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
 | 
			
		||||
  CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupiedScan_
 | 
			
		||||
      d_voxelOccupiedScan
 | 
			
		||||
          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      numVoxels);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
  cudaDeviceSynchronize();
 | 
			
		||||
 | 
			
		||||
  // Scan d_voxelVerts array to generate offsets of vertices for each voxel
 | 
			
		||||
  auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
 | 
			
		||||
  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // total number of vertices
 | 
			
		||||
  int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
 | 
			
		||||
  int64_t totalVerts = 0;
 | 
			
		||||
  auto d_voxelVertsScan = ExclusiveScanAndTotal(d_voxelVerts, &totalVerts);
 | 
			
		||||
 | 
			
		||||
  // Execute "GenerateFacesKernel" kernel
 | 
			
		||||
  // This runs only on the occupied voxels.
 | 
			
		||||
@ -524,7 +554,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
      faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
      ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      activeVoxels,
 | 
			
		||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
			
		||||
      faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user