diff --git a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu index 389bfaee..cd367e1c 100644 --- a/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu +++ b/pytorch3d/csrc/rasterize_coarse/rasterize_coarse.cu @@ -183,6 +183,22 @@ __global__ void RasterizeCoarseCudaKernel( // this effectively allocates space in the bin_faces array for the // elems in the current chunk that fall into this bin. const int start = atomicAdd(elems_per_bin + elems_per_bin_idx, count); + if (start + count > M) { + // The number of elems in this bin is so big that they won't fit. + // We print a warning using CUDA's printf. This may be invisible + // to notebook users, but apparent to others. It would be nice to + // also have a Python-friendly warning, but it is not obvious + // how to do this without slowing down the normal case. + const char* warning = + "Bin size was too small in the coarse rasterization phase. " + "This caused an overflow, meaning output may be incomplete. " + "To solve, " + "try increasing max_faces_per_bin / max_points_per_bin, " + "decreasing bin_size, " + "or setting bin_size to -1 to use the naive rasterization."; + printf(warning); + continue; + } // Now loop over the binmask and write the active bits for this bin // out to bin_faces.