mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-01 03:12:49 +08:00
use less thrust, maybe help Windows
Summary: I think we include more Thrust than needed, and removing it may help with issues such as https://github.com/facebookresearch/pytorch3d/issues/1610 (DebugSyncStream errors on Windows).
Reviewed By: shapovalov
Differential Revision: D48949888
fbshipit-source-id: add889c0acf730a039dc9ffd6bbcc24ded20ef27
This commit is contained in:
parent
a3d99cab6b
commit
6f2212da46
@ -12,8 +12,6 @@
|
||||
#include <math.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/tuple.h>
|
||||
#include "iou_box3d/iou_utils.cuh"
|
||||
|
||||
// Parallelize over N*M computations which can each be done
|
||||
|
@ -8,7 +8,6 @@
|
||||
|
||||
#include <float.h>
|
||||
#include <math.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <cstdio>
|
||||
#include "utils/float_math.cuh"
|
||||
|
||||
|
@ -9,8 +9,6 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/scan.h>
|
||||
#include <cstdio>
|
||||
#include "marching_cubes/tables.h"
|
||||
|
||||
@ -40,20 +38,6 @@ through" each cube in the grid.
|
||||
// EPSILON: tolerance used to indicate if two float values are close (declared __constant__)
|
||||
__constant__ const float EPSILON = 1e-5;
|
||||
|
||||
// Thrust wrapper for an exclusive scan (prefix sum) over an on-device int array
|
||||
//
|
||||
// Args:
|
||||
// output: pointer to on-device output array; receives one scanned value per input element
|
||||
// input: pointer to on-device input array, where scan is performed
|
||||
// numElements: number of elements for the input array
|
||||
|
|
||||
// NOTE(review): no execution policy is passed to thrust, so the scan presumably does not run on the caller's stream - confirm
|
||||
void ThrustScanWrapper(int* output, int* input, int numElements) {
|
||||
thrust::exclusive_scan(
|
||||
thrust::device_ptr<int>(input),
|
||||
thrust::device_ptr<int>(input + numElements),
|
||||
thrust::device_ptr<int>(output));
|
||||
}
|
||||
|
||||
// Linearly interpolate the position where an isosurface cuts an edge
|
||||
// between two vertices, based on their scalar values
|
||||
//
|
||||
@ -455,19 +439,24 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
grid.x = 65535;
|
||||
}
|
||||
|
||||
using at::indexing::None;
|
||||
using at::indexing::Slice;
|
||||
|
||||
auto d_voxelVerts =
|
||||
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
|
||||
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
auto d_voxelVerts_ = d_voxelVerts.index({Slice(1, None)});
|
||||
auto d_voxelOccupied =
|
||||
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
|
||||
at::zeros({numVoxels + 1}, at::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
auto d_voxelOccupied_ = d_voxelOccupied.index({Slice(1, None)});
|
||||
|
||||
// Execute "ClassifyVoxelKernel" kernel to precompute
|
||||
// two arrays - d_voxelOccupied and d_voxelVerts - in global memory,
|
||||
// which stores the occupancy state and number of voxel vertices per voxel.
|
||||
ClassifyVoxelKernel<<<grid, threads, 0, stream>>>(
|
||||
d_voxelVerts.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVerts_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
|
||||
isolevel);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
@ -477,18 +466,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
// count for voxels in the grid and compute the number of active voxels.
|
||||
// If the number of active voxels is 0, return zero tensor for verts and
|
||||
// faces.
|
||||
auto d_voxelOccupiedScan =
|
||||
at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
|
||||
.to(vol.device());
|
||||
ThrustScanWrapper(
|
||||
d_voxelOccupiedScan.data_ptr<int>(),
|
||||
d_voxelOccupied.data_ptr<int>(),
|
||||
numVoxels);
|
||||
|
||||
auto d_voxelOccupiedScan = at::cumsum(d_voxelOccupied, 0);
|
||||
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});
|
||||
|
||||
// number of active voxels
|
||||
int lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
|
||||
int lastScan = d_voxelOccupiedScan[numVoxels - 1].cpu().item<int>();
|
||||
int activeVoxels = lastElement + lastScan;
|
||||
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
|
||||
|
||||
const int device_id = vol.device().index();
|
||||
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
|
||||
@ -509,22 +492,17 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupiedScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
numVoxels);
|
||||
AT_CUDA_CHECK(cudaGetLastError());
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
// Scan d_voxelVerts array to generate offsets of vertices for each voxel
|
||||
auto d_voxelVertsScan = at::zeros({numVoxels}, opt);
|
||||
ThrustScanWrapper(
|
||||
d_voxelVertsScan.data_ptr<int>(),
|
||||
d_voxelVerts.data_ptr<int>(),
|
||||
numVoxels);
|
||||
auto d_voxelVertsScan = at::cumsum(d_voxelVerts, 0);
|
||||
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});
|
||||
|
||||
// total number of vertices
|
||||
lastElement = d_voxelVerts[numVoxels - 1].cpu().item<int>();
|
||||
lastScan = d_voxelVertsScan[numVoxels - 1].cpu().item<int>();
|
||||
int totalVerts = lastElement + lastScan;
|
||||
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
|
||||
|
||||
// Execute "GenerateFacesKernel" kernel
|
||||
// This runs only on the occupied voxels.
|
||||
@ -544,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
|
||||
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
|
||||
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
|
||||
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVertsScan.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
|
||||
activeVoxels,
|
||||
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
|
||||
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
|
||||
|
Loading…
x
Reference in New Issue
Block a user