mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Fix windows build (#1689)
Summary: Change the data type usage in the code to ensure cross-platform compatibility long -> int64_t <img width="628" alt="image" src="https://github.com/facebookresearch/pytorch3d/assets/6214316/40041f7f-3c09-4571-b9ff-676c625802e9"> Tested under Win 11 and Ubuntu 22.04 with CUDA 12.1.1 torch 2.1.1 Related issues & PR https://github.com/facebookresearch/pytorch3d/pull/9 https://github.com/facebookresearch/pytorch3d/issues/1679 Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/1689 Reviewed By: MichaelRamamonjisoa Differential Revision: D51521562 Pulled By: bottler fbshipit-source-id: d8ea81e223c642e0e9fb283f5d7efc9d6ac00d93
This commit is contained in:
		
							parent
							
								
									83bacda8fb
								
							
						
					
					
						commit
						7606854ff7
					
				@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
 | 
			
		||||
        compactedVoxelArray,
 | 
			
		||||
    const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
			
		||||
        voxelOccupied,
 | 
			
		||||
    const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
 | 
			
		||||
    const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
 | 
			
		||||
        voxelOccupiedScan,
 | 
			
		||||
    uint numVoxels) {
 | 
			
		||||
  uint id = blockIdx.x * blockDim.x + threadIdx.x;
 | 
			
		||||
@ -255,7 +255,8 @@ __global__ void GenerateFacesKernel(
 | 
			
		||||
    at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
 | 
			
		||||
    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
 | 
			
		||||
        compactedVoxelArray,
 | 
			
		||||
    at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
 | 
			
		||||
    at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
 | 
			
		||||
        numVertsScanned,
 | 
			
		||||
    const uint activeVoxels,
 | 
			
		||||
    const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
 | 
			
		||||
    const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
 | 
			
		||||
@ -471,7 +472,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // number of active voxels
 | 
			
		||||
  int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();
 | 
			
		||||
  int64_t activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int64_t>();
 | 
			
		||||
 | 
			
		||||
  const int device_id = vol.device().index();
 | 
			
		||||
  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
 | 
			
		||||
@ -492,7 +493,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupiedScan_
 | 
			
		||||
          .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      numVoxels);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
  cudaDeviceSynchronize();
 | 
			
		||||
@ -502,7 +504,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // total number of vertices
 | 
			
		||||
  int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();
 | 
			
		||||
  int64_t totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int64_t>();
 | 
			
		||||
 | 
			
		||||
  // Execute "GenerateFacesKernel" kernel
 | 
			
		||||
  // This runs only on the occupied voxels.
 | 
			
		||||
@ -522,7 +524,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
      faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
      ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan_.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      activeVoxels,
 | 
			
		||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
			
		||||
      faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user