mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-03-11 23:06:04 +08:00
Fix CUDA kernel index data type in vision/fair/pytorch3d/pytorch3d/csrc/compositing/alpha_composite.cu +10
Summary: CUDA kernel variables matching the type `(thread|block|grid).(Idx|Dim).(x|y|z)` [have the data type `uint`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#built-in-variables). Many programmers mistakenly use implicit casts to turn these data types into `int`. In fact, the [CUDA Programming Guide](https://docs.nvidia.com/cuda/cuda-c-programming-guide/) it self is inconsistent and incorrect in its use of data types in programming examples. The result of these implicit casts is that our kernels may give unexpected results when exposed to large datasets, i.e., those exceeding >~2B items. While we now have linters in place to prevent simple mistakes (D71236150), our codebase has many problematic instances. This diff fixes some of them. Reviewed By: dtolnay Differential Revision: D71355356 fbshipit-source-id: cea44891416d9efd2f466d6c45df4e36008fa036
This commit is contained in:
committed by
Facebook GitHub Bot
parent
06a76ef8dd
commit
3987612062
@@ -260,8 +260,8 @@ __global__ void RasterizeMeshesNaiveCudaKernel(
|
||||
float* pix_dists,
|
||||
float* bary) {
|
||||
// Simple version: One thread per output pixel
|
||||
int num_threads = gridDim.x * blockDim.x;
|
||||
int tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
auto num_threads = gridDim.x * blockDim.x;
|
||||
auto tid = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
for (int i = tid; i < N * H * W; i += num_threads) {
|
||||
// Convert linear index to 3D index
|
||||
@@ -446,8 +446,8 @@ __global__ void RasterizeMeshesBackwardCudaKernel(
|
||||
|
||||
// Parallelize over each pixel in images of
|
||||
// size H * W, for each image in the batch of size N.
|
||||
const int num_threads = gridDim.x * blockDim.x;
|
||||
const int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto num_threads = gridDim.x * blockDim.x;
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int t_i = tid; t_i < N * H * W; t_i += num_threads) {
|
||||
// Convert linear index to 3D index
|
||||
@@ -650,8 +650,8 @@ __global__ void RasterizeMeshesFineCudaKernel(
|
||||
) {
|
||||
// This can be more than H * W if H or W are not divisible by bin_size.
|
||||
int num_pixels = N * BH * BW * bin_size * bin_size;
|
||||
int num_threads = gridDim.x * blockDim.x;
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
auto num_threads = gridDim.x * blockDim.x;
|
||||
auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (int pid = tid; pid < num_pixels; pid += num_threads) {
|
||||
// Convert linear index into bin and pixel indices. We make the within
|
||||
|
||||
Reference in New Issue
Block a user