From 3987612062f3db5dba609df3552768dcd97b410f Mon Sep 17 00:00:00 2001 From: Richard Barnes Date: Wed, 19 Mar 2025 13:21:43 -0700 Subject: [PATCH] Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10 Summary: CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables). Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples. The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items. While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them. Reviewed By: dtolnay Differential Revision: D71355356 fbshipit-source-id: cea44891416d9efd2f466d6c45df4e36008fa036 --- pytorch3d/csrc/compositing/alpha_composite.cu | 12 ++++++------ .../csrc/compositing/norm_weighted_sum.cu | 12 ++++++------ pytorch3d/csrc/compositing/weighted_sum.cu | 12 ++++++------ .../csrc/gather_scatter/gather_scatter.cu | 6 +++--- .../interp_face_attrs/interp_face_attrs.cu | 8 ++++---- pytorch3d/csrc/point_mesh/point_mesh_cuda.cu | 18 +++++++++--------- pytorch3d/csrc/rasterize_coarse/bitmask.cuh | 2 +- .../csrc/rasterize_coarse/rasterize_coarse.cu | 14 +++++++------- .../csrc/rasterize_meshes/rasterize_meshes.cu | 12 ++++++------ .../csrc/rasterize_points/rasterize_points.cu | 12 ++++++------ 10 files changed, 54 insertions(+), 54 deletions(-) diff --git a/pytorch3d/csrc/compositing/alpha_composite.cu b/pytorch3d/csrc/compositing/alpha_composite.cu index b5d512e8..2bfe79dc 100644 --- a/pytorch3d/csrc/compositing/alpha_composite.cu +++ b/pytorch3d/csrc/compositing/alpha_composite.cu @@ -33,11 +33,11 @@ __global__ void alphaCompositeCudaForwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * H * W; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each feature in each pixel for (int pid = tid; pid < num_pixels; pid += num_threads) { @@ -83,11 +83,11 @@ __global__ void alphaCompositeCudaBackwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * H * W; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size diff --git a/pytorch3d/csrc/compositing/norm_weighted_sum.cu b/pytorch3d/csrc/compositing/norm_weighted_sum.cu index 455bdb7f..e21617d2 100644 --- a/pytorch3d/csrc/compositing/norm_weighted_sum.cu +++ b/pytorch3d/csrc/compositing/norm_weighted_sum.cu @@ -33,11 +33,11 @@ __global__ void weightedSumNormCudaForwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * H * W; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size @@ -96,11 +96,11 @@ __global__ void weightedSumNormCudaBackwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * W * H; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size diff --git a/pytorch3d/csrc/compositing/weighted_sum.cu b/pytorch3d/csrc/compositing/weighted_sum.cu index 125688a1..2e0904e7 100644 --- a/pytorch3d/csrc/compositing/weighted_sum.cu +++ b/pytorch3d/csrc/compositing/weighted_sum.cu @@ -31,11 +31,11 @@ __global__ void weightedSumCudaForwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * H * W; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Parallelize over each feature in each pixel in images of size H * W, // for each image in the batch of size batch_size @@ -78,11 +78,11 @@ __global__ void weightedSumCudaBackwardKernel( const int64_t W = points_idx.size(3); // Get the batch and index - const int batch = blockIdx.x; + const auto batch = blockIdx.x; const int num_pixels = C * H * W; - const int num_threads = gridDim.y * blockDim.x; - const int tid = blockIdx.y * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.y * blockDim.x; + const auto tid = blockIdx.y * blockDim.x + threadIdx.x; // Iterate over each pixel to compute the contribution to the // gradient for the features and weights diff --git a/pytorch3d/csrc/gather_scatter/gather_scatter.cu b/pytorch3d/csrc/gather_scatter/gather_scatter.cu index 1ec1a6f2..d4affd4b 100644 --- a/pytorch3d/csrc/gather_scatter/gather_scatter.cu +++ b/pytorch3d/csrc/gather_scatter/gather_scatter.cu @@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel( const size_t V, const size_t D, const size_t E) { - const int tid = threadIdx.x; + const auto tid = threadIdx.x; // Reverse the vertex order if backward. const int v0_idx = backward ? 1 : 0; const int v1_idx = backward ? 0 : 1; // Edges are split evenly across the blocks. - for (int e = blockIdx.x; e < E; e += gridDim.x) { + for (auto e = blockIdx.x; e < E; e += gridDim.x) { // Get indices of vertices which form the edge. const int64_t v0 = edges[2 * e + v0_idx]; const int64_t v1 = edges[2 * e + v1_idx]; @@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel( // Split vertex features evenly across threads. // This implementation will be quite wasteful when D<128 since there will be // a lot of threads doing nothing. - for (int d = tid; d < D; d += blockDim.x) { + for (auto d = tid; d < D; d += blockDim.x) { const float val = input[v1 * D + d]; float* address = output + v0 * D + d; atomicAdd(address, val); diff --git a/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu index 6bd2a80d..8fe292ae 100644 --- a/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu +++ b/pytorch3d/csrc/interp_face_attrs/interp_face_attrs.cu @@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel( const size_t P, const size_t F, const size_t D) { - const int tid = threadIdx.x + blockIdx.x * blockDim.x; - const int num_threads = blockDim.x * gridDim.x; + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto num_threads = blockDim.x * gridDim.x; for (int pd = tid; pd < P * D; pd += num_threads) { const int p = pd / D; const int d = pd % D; @@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel( const size_t P, const size_t F, const size_t D) { - const int tid = threadIdx.x + blockIdx.x * blockDim.x; - const int num_threads = blockDim.x * gridDim.x; + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + const auto num_threads = blockDim.x * gridDim.x; for (int pd = tid; pd < P * D; pd += num_threads) { const int p = pd / D; const int d = pd % D; diff --git a/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu index 3788d405..606ec9e6 100644 --- a/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu +++ b/pytorch3d/csrc/point_mesh/point_mesh_cuda.cu @@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel( __syncthreads(); // Perform reduction in shared memory. - for (int s = blockDim.x / 2; s > 32; s >>= 1) { + for (auto s = blockDim.x / 2; s > 32; s >>= 1) { if (tid < s) { if (min_dists[tid] > min_dists[tid + s]) { min_dists[tid] = min_dists[tid + s]; @@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel( const float3* tris_f3 = (float3*)tris; // Parallelize over P * S computations - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int t_i = tid; t_i < P * T; t_i += num_threads) { const int t = t_i / P; // segment index. @@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel( const float3* tris_f3 = (float3*)tris; // Parallelize over P * S computations - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int t_i = tid; t_i < P * T; t_i += num_threads) { const int t = t_i / P; // triangle index. @@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel( float3* segms_f3 = (float3*)segms; // Parallelize over P * S computations - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int t_i = tid; t_i < P * S; t_i += num_threads) { const int s = t_i / P; // segment index. @@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel( float3* segms_f3 = (float3*)segms; // Parallelize over P * S computations - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int t_i = tid; t_i < P * S; t_i += num_threads) { const int s = t_i / P; // segment index. diff --git a/pytorch3d/csrc/rasterize_coarse/bitmask.cuh b/pytorch3d/csrc/rasterize_coarse/bitmask.cuh index 6ffcac87..729650ba 100644 --- a/pytorch3d/csrc/rasterize_coarse/bitmask.cuh +++ b/pytorch3d/csrc/rasterize_coarse/bitmask.cuh @@ -25,7 +25,7 @@ class BitMask { // Use all threads in the current block to clear all bits of this BitMask __device__ void block_clear() { - for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) { + for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) { data[i] = 0; } __syncthreads(); diff --git a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu index aed57d21..f093ef05 100644 --- a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu +++ b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu @@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel( const float blur_radius, float* bboxes, // (4, F) bool* skip_face) { // (F,) - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - const int num_threads = blockDim.x * gridDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = blockDim.x * gridDim.x; const float sqrt_radius = sqrt(blur_radius); for (int f = tid; f < F; f += num_threads) { const float v0x = face_verts[f * 9 + 0 * 3 + 0]; @@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel( const int P, float* bboxes, // (4, P) bool* skip_points) { - const int tid = blockIdx.x * blockDim.x + threadIdx.x; - const int num_threads = blockDim.x * gridDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = blockDim.x * gridDim.x; for (int p = tid; p < P; p += num_threads) { const float x = points[p * 3 + 0]; const float y = points[p * 3 + 1]; @@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel( const int chunks_per_batch = 1 + (E - 1) / chunk_size; const int num_chunks = N * chunks_per_batch; - for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) { + for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) { const int batch_idx = chunk / chunks_per_batch; // batch index const int chunk_idx = chunk % chunks_per_batch; const int elem_chunk_start_idx = chunk_idx * chunk_size; @@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel( const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx]; // Have each thread handle a different face within the chunk - for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) { + for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) { const int e_idx = elem_chunk_start_idx + e; // Check that we are still within the same element of the batch @@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel( // Now we have processed every elem in the current chunk. We need to // count the number of elems in each bin so we can write the indices // out to global memory. We have each thread handle a different bin. - for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x; + for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x; byx += blockDim.x) { const int by = byx / num_bins_x; const int bx = byx % num_bins_x; diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index 9dd3e266..28c546c6 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel( float* pix_dists, float* bary) { // Simple version: One thread per output pixel - int num_threads = gridDim.x * blockDim.x; - int tid = blockDim.x * blockIdx.x + threadIdx.x; + auto num_threads = gridDim.x * blockDim.x; + auto tid = blockDim.x * blockIdx.x + threadIdx.x; for (int i = tid; i < N * H * W; i += num_threads) { // Convert linear index to 3D index @@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel( // Parallelize over each pixel in images of // size H * W, for each image in the batch of size N. - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int t_i = tid; t_i < N * H * W; t_i += num_threads) { // Convert linear index to 3D index @@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel( ) { // This can be more than H * W if H or W are not divisible by bin_size. int num_pixels = N * BH * BW * bin_size * bin_size; - int num_threads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; + auto num_threads = gridDim.x * blockDim.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int pid = tid; pid < num_pixels; pid += num_threads) { // Convert linear index into bin and pixel indices. We make the within diff --git a/pytorch3d/csrc/rasterize_points/rasterize_points.cu b/pytorch3d/csrc/rasterize_points/rasterize_points.cu index 5b18d833..20bf0de7 100644 --- a/pytorch3d/csrc/rasterize_points/rasterize_points.cu +++ b/pytorch3d/csrc/rasterize_points/rasterize_points.cu @@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel( float* zbuf, // (N, H, W, K) float* pix_dists) { // (N, H, W, K) // Simple version: One thread per output pixel - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockDim.x * blockIdx.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockDim.x * blockIdx.x + threadIdx.x; for (int i = tid; i < N * H * W; i += num_threads) { // Convert linear index to 3D index const int n = i / (H * W); // Batch index @@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel( float* pix_dists) { // (N, H, W, K) // This can be more than H * W if H or W are not divisible by bin_size. const int num_pixels = N * BH * BW * bin_size * bin_size; - const int num_threads = gridDim.x * blockDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; + const auto num_threads = gridDim.x * blockDim.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int pid = tid; pid < num_pixels; pid += num_threads) { // Convert linear index into bin and pixel indices. We make the within @@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel( float* grad_points) { // (P, 3) // Parallelized over each of K points per pixel, for each pixel in images of // size H * W, for each image in the batch of size N. - int num_threads = gridDim.x * blockDim.x; - int tid = blockIdx.x * blockDim.x + threadIdx.x; + auto num_threads = gridDim.x * blockDim.x; + auto tid = blockIdx.x * blockDim.x + threadIdx.x; for (int i = tid; i < N * H * W * K; i += num_threads) { // const int n = i / (H * W * K); // batch index (not needed). const int yxk = i % (H * W * K);