mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
CUB usage fix for sample_farthest_points
Summary: Fix for https://github.com/facebookresearch/pytorch3d/issues/1529 Reviewed By: shapovalov Differential Revision: D45569211 fbshipit-source-id: 8c485f26cd409cafac53d4d982a03cde81a1d853
This commit is contained in:
parent
c8d6cd427e
commit
b921efae3e
@ -155,7 +155,7 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
|
|
||||||
// Max possible threads per block
|
// Max possible threads per block
|
||||||
const int MAX_THREADS_PER_BLOCK = 1024;
|
const int MAX_THREADS_PER_BLOCK = 1024;
|
||||||
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 1);
|
const size_t threads = max(min(1 << points_pow_2, MAX_THREADS_PER_BLOCK), 2);
|
||||||
|
|
||||||
// Create the accessors
|
// Create the accessors
|
||||||
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
|
auto points_a = points.packed_accessor64<float, 3, at::RestrictPtrTraits>();
|
||||||
@ -215,10 +215,6 @@ at::Tensor FarthestPointSamplingCuda(
|
|||||||
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
|
FarthestPointSamplingKernel<2><<<threads, threads, shared_mem, stream>>>(
|
||||||
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
|
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
|
||||||
break;
|
break;
|
||||||
case 1:
|
|
||||||
FarthestPointSamplingKernel<1><<<threads, threads, shared_mem, stream>>>(
|
|
||||||
points_a, lengths_a, K_a, idxs_a, min_point_dist_a, start_idxs_a);
|
|
||||||
break;
|
|
||||||
default:
|
default:
|
||||||
FarthestPointSamplingKernel<1024>
|
FarthestPointSamplingKernel<1024>
|
||||||
<<<blocks, threads, shared_mem, stream>>>(
|
<<<blocks, threads, shared_mem, stream>>>(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user