mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Use less Thrust, which may help on Windows
Summary: I think we include more Thrust than needed, and removing it may help with problems like https://github.com/facebookresearch/pytorch3d/issues/1610, which reports DebugSyncStream errors on Windows.
Reviewed By: shapovalov
Differential Revision: D48949888
fbshipit-source-id: add889c0acf730a039dc9ffd6bbcc24ded20ef27
This commit is contained in:
		
							parent
							
								
									a3d99cab6b
								
							
						
					
					
						commit
						6f2212da46
					
				@ -12,8 +12,6 @@
 | 
			
		||||
#include <math.h>
 | 
			
		||||
#include <stdio.h>
 | 
			
		||||
#include <stdlib.h>
 | 
			
		||||
#include <thrust/device_vector.h>
 | 
			
		||||
#include <thrust/tuple.h>
 | 
			
		||||
#include "iou_box3d/iou_utils.cuh"
 | 
			
		||||
 | 
			
		||||
// Parallelize over N*M computations which can each be done
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,6 @@
 | 
			
		||||
 | 
			
		||||
#include <float.h>
 | 
			
		||||
#include <math.h>
 | 
			
		||||
#include <thrust/device_vector.h>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
#include "utils/float_math.cuh"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,8 +9,6 @@
 | 
			
		||||
#include <ATen/ATen.h>
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <thrust/device_vector.h>
 | 
			
		||||
#include <thrust/scan.h>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
#include "marching_cubes/tables.h"
 | 
			
		||||
 | 
			
		||||
@ -40,20 +38,6 @@ through" each cube in the grid.
 | 
			
		||||
// EPS: Used to indicate if two float values are close
 | 
			
		||||
__constant__ const float EPSILON = 1e-5;
 | 
			
		||||
 | 
			
		||||
// Thrust wrapper for exclusive scan
 | 
			
		||||
//
 | 
			
		||||
// Args:
 | 
			
		||||
//    output: pointer to on-device output array
 | 
			
		||||
//    input: pointer to on-device input array, where scan is performed
 | 
			
		||||
//    numElements: number of elements for the input array
 | 
			
		||||
//
 | 
			
		||||
void ThrustScanWrapper(int* output, int* input, int numElements) {
 | 
			
		||||
  thrust::exclusive_scan(
 | 
			
		||||
      thrust::device_ptr<int>(input),
 | 
			
		||||
      thrust::device_ptr<int>(input + numElements),
 | 
			
		||||
      thrust::device_ptr<int>(output));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Linearly interpolate the position where an isosurface cuts an edge
 | 
			
		||||
// between two vertices, based on their scalar values
 | 
			
		||||
//
 | 
			
		||||
@ -455,19 +439,24 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
    grid.x = 65535;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  using at::indexing::None;
 | 
			
		||||
  using at::indexing::Slice;
 | 
			
		||||
 | 
			
		||||
  auto d_voxelVerts =
 | 
			
		||||
      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
          .to(vol.device());
 | 
			
		||||
  auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
 | 
			
		||||
  auto d_voxelOccupied =
 | 
			
		||||
      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
      at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
          .to(vol.device());
 | 
			
		||||
  auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // Execute "ClassifyVoxelKernel" kernel to precompute
 | 
			
		||||
  // two arrays - d_voxelOccupied and d_voxelVertices to global memory,
 | 
			
		||||
  // which stores the occupancy state and number of voxel vertices per voxel.
 | 
			
		||||
  ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
 | 
			
		||||
      d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
			
		||||
      isolevel);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
@ -477,18 +466,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  // count for voxels in the grid and compute the number of active voxels.
 | 
			
		||||
  // If the number of active voxels is 0, return zero tensor for verts and
 | 
			
		||||
  // faces.
 | 
			
		||||
  auto d_voxelOccupiedScan =
 | 
			
		||||
      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
 | 
			
		||||
          .to(vol.device());
 | 
			
		||||
  ThrustScanWrapper(
 | 
			
		||||
      d_voxelOccupiedScan.data_ptr<int>(),
 | 
			
		||||
      d_voxelOccupied.data_ptr<int>(),
 | 
			
		||||
      numVoxels);
 | 
			
		||||
 | 
			
		||||
  auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
 | 
			
		||||
  auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // number of active voxels
 | 
			
		||||
  int lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
 | 
			
		||||
  int lastScan = d_voxelOccupiedScan[numVoxels - 1].cpu().item<int>();
 | 
			
		||||
  int activeVoxels = lastElement + lastScan;
 | 
			
		||||
  int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
 | 
			
		||||
 | 
			
		||||
  const int device_id = vol.device().index();
 | 
			
		||||
  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
 | 
			
		||||
@ -509,22 +492,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
  CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupiedScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      numVoxels);
 | 
			
		||||
  AT_CUDA_CHECK(cudaGetLastError());
 | 
			
		||||
  cudaDeviceSynchronize();
 | 
			
		||||
 | 
			
		||||
  // Scan d_voxelVerts array to generate offsets of vertices for each voxel
 | 
			
		||||
  auto d_voxelVertsScan = at::zeros({numVoxels}, opt);
 | 
			
		||||
  ThrustScanWrapper(
 | 
			
		||||
      d_voxelVertsScan.data_ptr<int>(),
 | 
			
		||||
      d_voxelVerts.data_ptr<int>(),
 | 
			
		||||
      numVoxels);
 | 
			
		||||
  auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
 | 
			
		||||
  auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
 | 
			
		||||
 | 
			
		||||
  // total number of vertices
 | 
			
		||||
  lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
 | 
			
		||||
  lastScan = d_voxelVertsScan[numVoxels - 1].cpu().item<int>();
 | 
			
		||||
  int totalVerts = lastElement + lastScan;
 | 
			
		||||
  int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
 | 
			
		||||
 | 
			
		||||
  // Execute "GenerateFacesKernel" kernel
 | 
			
		||||
  // This runs only on the occupied voxels.
 | 
			
		||||
@ -544,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
 | 
			
		||||
      faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
      ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
 | 
			
		||||
      activeVoxels,
 | 
			
		||||
      vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
 | 
			
		||||
      faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user