Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10

Summary:
CUDA kernel variables matching the pattern `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables).

Many programmers mistakenly rely on implicit conversions that turn these values into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) itself is inconsistent and incorrect in its use of data types in its programming examples.

The result of these implicit conversions is that our kernels may give unexpected results on large datasets, i.e., those with more than roughly 2 billion items, beyond the range of a 32-bit `int`.
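As a minimal sketch of the problem (a hypothetical kernel, not taken from this diff): the built-ins are `unsigned int`, so storing a thread-id expression in an `int` silently narrows it, while `auto` keeps the unsigned type; the 64-bit loop variable here is an extra precaution chosen for this sketch when the element count can exceed 2^31.

```cuda
#include <cstdint>

// Hypothetical kernel, for illustration only: N may exceed INT_MAX.
__global__ void scaleKernel(float* data, int64_t N, float s) {
  // Problematic pattern: the right-hand side is computed in unsigned int,
  // then implicitly converted to int; for very large grids the stored value
  // can exceed INT_MAX and the resulting index is wrong.
  //   const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  // Fixed pattern: auto deduces the built-ins' unsigned int type.
  const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
  const auto num_threads = gridDim.x * blockDim.x;

  // Grid-stride loop with a 64-bit loop variable so the comparison against
  // the int64_t bound N stays well defined past 2^31 elements.
  for (int64_t i = tid; i < N; i += num_threads) {
    data[i] = s * data[i];
  }
}
```

The hunks below apply the same idea, replacing `int` with `auto` for values derived from `threadIdx`, `blockIdx`, `blockDim`, and `gridDim`.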

While we now have linters in place to prevent simple mistakes (D71236150), our codebase still contains many problematic instances. This diff fixes some of them.

Reviewed By: dtolnay

Differential Revision: D71355356

fbshipit-source-id: cea44891416d9efd2f466d6c45df4e36008fa036
Richard Barnes 2025-03-19 13:21:43 -07:00 committed by Facebook GitHub Bot
parent 06a76ef8dd
commit 3987612062
10 changed files with 54 additions and 54 deletions

View File

@@ -33,11 +33,11 @@ __global__ void alphaCompositeCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each feature in each pixel
for (int pid = tid; pid < num_pixels; pid += num_threads) {
@@ -83,11 +83,11 @@ __global__ void alphaCompositeCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size

View File

@@ -33,11 +33,11 @@ __global__ void weightedSumNormCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -96,11 +96,11 @@ __global__ void weightedSumNormCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * W * H;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size

View File

@@ -31,11 +31,11 @@ __global__ void weightedSumCudaForwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Parallelize over each feature in each pixel in images of size H * W,
// for each image in the batch of size batch_size
@@ -78,11 +78,11 @@ __global__ void weightedSumCudaBackwardKernel(
const int64_t W = points_idx.size(3);
// Get the batch and index
- const int batch = blockIdx.x;
+ const auto batch = blockIdx.x;
const int num_pixels = C * H * W;
- const int num_threads = gridDim.y * blockDim.x;
- const int tid = blockIdx.y * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.y * blockDim.x;
+ const auto tid = blockIdx.y * blockDim.x + threadIdx.x;
// Iterate over each pixel to compute the contribution to the
// gradient for the features and weights

View File

@@ -20,14 +20,14 @@ __global__ void GatherScatterCudaKernel(
const size_t V,
const size_t D,
const size_t E) {
- const int tid = threadIdx.x;
+ const auto tid = threadIdx.x;
// Reverse the vertex order if backward.
const int v0_idx = backward ? 1 : 0;
const int v1_idx = backward ? 0 : 1;
// Edges are split evenly across the blocks.
- for (int e = blockIdx.x; e < E; e += gridDim.x) {
+ for (auto e = blockIdx.x; e < E; e += gridDim.x) {
// Get indices of vertices which form the edge.
const int64_t v0 = edges[2 * e + v0_idx];
const int64_t v1 = edges[2 * e + v1_idx];
@@ -35,7 +35,7 @@ __global__ void GatherScatterCudaKernel(
// Split vertex features evenly across threads.
// This implementation will be quite wasteful when D<128 since there will be
// a lot of threads doing nothing.
- for (int d = tid; d < D; d += blockDim.x) {
+ for (auto d = tid; d < D; d += blockDim.x) {
const float val = input[v1 * D + d];
float* address = output + v0 * D + d;
atomicAdd(address, val);

View File

@@ -20,8 +20,8 @@ __global__ void InterpFaceAttrsForwardKernel(
const size_t P,
const size_t F,
const size_t D) {
- const int tid = threadIdx.x + blockIdx.x * blockDim.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const auto num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;
@@ -93,8 +93,8 @@ __global__ void InterpFaceAttrsBackwardKernel(
const size_t P,
const size_t F,
const size_t D) {
- const int tid = threadIdx.x + blockIdx.x * blockDim.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
+ const auto num_threads = blockDim.x * gridDim.x;
for (int pd = tid; pd < P * D; pd += num_threads) {
const int p = pd / D;
const int d = pd % D;

View File

@@ -110,7 +110,7 @@ __global__ void DistanceForwardKernel(
__syncthreads();
// Perform reduction in shared memory.
- for (int s = blockDim.x / 2; s > 32; s >>= 1) {
+ for (auto s = blockDim.x / 2; s > 32; s >>= 1) {
if (tid < s) {
if (min_dists[tid] > min_dists[tid + s]) {
min_dists[tid] = min_dists[tid + s];
@@ -502,8 +502,8 @@ __global__ void PointFaceArrayForwardKernel(
const float3* tris_f3 = (float3*)tris;
// Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // segment index.
@@ -576,8 +576,8 @@ __global__ void PointFaceArrayBackwardKernel(
const float3* tris_f3 = (float3*)tris;
// Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * T; t_i += num_threads) {
const int t = t_i / P; // triangle index.
@@ -683,8 +683,8 @@ __global__ void PointEdgeArrayForwardKernel(
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.
@@ -752,8 +752,8 @@ __global__ void PointEdgeArrayBackwardKernel(
float3* segms_f3 = (float3*)segms;
// Parallelize over P * S computations
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < P * S; t_i += num_threads) {
const int s = t_i / P; // segment index.

View File

@@ -25,7 +25,7 @@ class BitMask {
// Use all threads in the current block to clear all bits of this BitMask
__device__ void block_clear() {
- for (int i = threadIdx.x; i < H * W * D; i += blockDim.x) {
+ for (auto i = threadIdx.x; i < H * W * D; i += blockDim.x) {
data[i] = 0;
}
__syncthreads();

View File

@@ -23,8 +23,8 @@ __global__ void TriangleBoundingBoxKernel(
const float blur_radius,
float* bboxes, // (4, F)
bool* skip_face) { // (F,)
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = blockDim.x * gridDim.x;
const float sqrt_radius = sqrt(blur_radius);
for (int f = tid; f < F; f += num_threads) {
const float v0x = face_verts[f * 9 + 0 * 3 + 0];
@@ -56,8 +56,8 @@ __global__ void PointBoundingBoxKernel(
const int P,
float* bboxes, // (4, P)
bool* skip_points) {
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
- const int num_threads = blockDim.x * gridDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = blockDim.x * gridDim.x;
for (int p = tid; p < P; p += num_threads) {
const float x = points[p * 3 + 0];
const float y = points[p * 3 + 1];
@@ -113,7 +113,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int chunks_per_batch = 1 + (E - 1) / chunk_size;
const int num_chunks = N * chunks_per_batch;
- for (int chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
+ for (auto chunk = blockIdx.x; chunk < num_chunks; chunk += gridDim.x) {
const int batch_idx = chunk / chunks_per_batch; // batch index
const int chunk_idx = chunk % chunks_per_batch;
const int elem_chunk_start_idx = chunk_idx * chunk_size;
@@ -123,7 +123,7 @@ __global__ void RasterizeCoarseCudaKernel(
const int64_t elem_stop_idx = elem_start_idx + elems_per_batch[batch_idx];
// Have each thread handle a different face within the chunk
- for (int e = threadIdx.x; e < chunk_size; e += blockDim.x) {
+ for (auto e = threadIdx.x; e < chunk_size; e += blockDim.x) {
const int e_idx = elem_chunk_start_idx + e;
// Check that we are still within the same element of the batch
@@ -170,7 +170,7 @@ __global__ void RasterizeCoarseCudaKernel(
// Now we have processed every elem in the current chunk. We need to
// count the number of elems in each bin so we can write the indices
// out to global memory. We have each thread handle a different bin.
- for (int byx = threadIdx.x; byx < num_bins_y * num_bins_x;
+ for (auto byx = threadIdx.x; byx < num_bins_y * num_bins_x;
byx += blockDim.x) {
const int by = byx / num_bins_x;
const int bx = byx % num_bins_x;

View File

@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
float* pix_dists,
float* bary) {
// Simple version: One thread per output pixel
- int num_threads = gridDim.x * blockDim.x;
- int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ auto num_threads = gridDim.x * blockDim.x;
+ auto tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
// Parallelize over each pixel in images of
// size H * W, for each image in the batch of size N.
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
// Convert linear index to 3D index
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
) {
// This can be more than H * W if H or W are not divisible by bin_size.
int num_pixels = N * BH * BW * bin_size * bin_size;
- int num_threads = gridDim.x * blockDim.x;
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ auto num_threads = gridDim.x * blockDim.x;
+ auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within

View File

@@ -97,8 +97,8 @@ __global__ void RasterizePointsNaiveCudaKernel(
float* zbuf, // (N, H, W, K)
float* pix_dists) { // (N, H, W, K)
// Simple version: One thread per output pixel
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockDim.x * blockIdx.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockDim.x * blockIdx.x + threadIdx.x;
for (int i = tid; i < N * H * W; i += num_threads) {
// Convert linear index to 3D index
const int n = i / (H * W); // Batch index
@@ -237,8 +237,8 @@ __global__ void RasterizePointsFineCudaKernel(
float* pix_dists) { // (N, H, W, K)
// This can be more than H * W if H or W are not divisible by bin_size.
const int num_pixels = N * BH * BW * bin_size * bin_size;
- const int num_threads = gridDim.x * blockDim.x;
- const int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const auto num_threads = gridDim.x * blockDim.x;
+ const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int pid = tid; pid < num_pixels; pid += num_threads) {
// Convert linear index into bin and pixel indices. We make the within
@@ -376,8 +376,8 @@ __global__ void RasterizePointsBackwardCudaKernel(
float* grad_points) { // (P, 3)
// Parallelized over each of K points per pixel, for each pixel in images of
// size H * W, for each image in the batch of size N.
- int num_threads = gridDim.x * blockDim.x;
- int tid = blockIdx.x * blockDim.x + threadIdx.x;
+ auto num_threads = gridDim.x * blockDim.x;
+ auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (int i = tid; i < N * H * W * K; i += num_threads) {
// const int n = i / (H * W * K); // batch index (not needed).
const int yxk = i % (H * W * K);