diff --git a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu index a11285e0..feb2c59a 100644 --- a/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu +++ b/pytorch3d/csrc/rasterize_meshes/rasterize_meshes.cu @@ -696,7 +696,7 @@ at::Tensor RasterizeMeshesCoarseCuda( const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up. const int M = max_faces_per_bin; - if (num_bins >= 22) { + if (num_bins >= kMaxFacesPerBin) { std::stringstream ss; ss << "Got " << num_bins << "; that's too many!"; AT_ERROR(ss.str()); diff --git a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh index 45fdda26..ddb585ba 100644 --- a/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh +++ b/pytorch3d/csrc/rasterize_points/rasterization_utils.cuh @@ -17,6 +17,8 @@ __device__ inline float PixToNdc(int i, int S) { // TODO: is 8 enough? Would increasing have performance considerations? const int32_t kMaxPointsPerPixel = 150; +const int32_t kMaxFacesPerBin = 22; + template __device__ inline void BubbleSort(T* arr, int n) { // Bubble sort. We only use it for tiny thread-local arrays (n < 8); in this diff --git a/pytorch3d/renderer/mesh/rasterize_meshes.py b/pytorch3d/renderer/mesh/rasterize_meshes.py index 67cb8a79..78ab7759 100644 --- a/pytorch3d/renderer/mesh/rasterize_meshes.py +++ b/pytorch3d/renderer/mesh/rasterize_meshes.py @@ -11,6 +11,10 @@ from pytorch3d import _C # TODO make the epsilon user configurable kEpsilon = 1e-8 +# Maxinum number of faces per bins for +# coarse-to-fine rasterization +kMaxFacesPerBin = 22 + def rasterize_meshes( meshes, @@ -107,12 +111,23 @@ def rasterize_meshes( # TODO better heuristics for bin size. if image_size <= 64: bin_size = 8 - elif image_size <= 256: - bin_size = 16 - elif image_size <= 512: - bin_size = 32 - elif image_size <= 1024: - bin_size = 64 + else: + # Heuristic based formula maps image_size -> bin_size as follows: + # image_size < 64 -> 8 + # 16 < image_size < 256 -> 16 + # 256 < image_size < 512 -> 32 + # 512 < image_size < 1024 -> 64 + # 1024 < image_size < 2048 -> 128 + bin_size = int(2 ** max(np.ceil(np.log2(image_size)) - 4, 4)) + + if bin_size != 0: + # There is a limit on the number of faces per bin in the cuda kernel. + faces_per_bin = 1 + (image_size - 1) // bin_size + if faces_per_bin >= kMaxFacesPerBin: + raise ValueError( + "bin_size too small, number of faces per bin must be less than %d; got %d" + % (kMaxFacesPerBin, faces_per_bin) + ) if max_faces_per_bin is None: max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5)) diff --git a/pytorch3d/renderer/points/rasterize_points.py b/pytorch3d/renderer/points/rasterize_points.py index 6e3e71a6..51ff2815 100644 --- a/pytorch3d/renderer/points/rasterize_points.py +++ b/pytorch3d/renderer/points/rasterize_points.py @@ -7,6 +7,11 @@ from pytorch3d import _C from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc +# Maxinum number of faces per bins for +# coarse-to-fine rasterization +kMaxPointsPerBin = 22 + + # TODO(jcjohns): Support non-square images def rasterize_points( pointclouds, @@ -82,6 +87,15 @@ def rasterize_points( elif image_size <= 1024: bin_size = 64 + if bin_size != 0: + # There is a limit on the number of points per bin in the cuda kernel. + points_per_bin = 1 + (image_size - 1) // bin_size + if points_per_bin >= kMaxPointsPerBin: + raise ValueError( + "bin_size too small, number of points per bin must be less than %d; got %d" + % (kMaxPointsPerBin, points_per_bin) + ) + if max_points_per_bin is None: max_points_per_bin = int(max(10000, points_packed.shape[0] / 5)) diff --git a/tests/test_rasterize_meshes.py b/tests/test_rasterize_meshes.py index bb2441d8..a2530209 100644 --- a/tests/test_rasterize_meshes.py +++ b/tests/test_rasterize_meshes.py @@ -382,6 +382,13 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase): args = () self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True) + def test_bin_size_error(self): + meshes = ico_sphere(2) + image_size = 1024 + bin_size = 16 + with self.assertRaisesRegex(ValueError, "bin_size too small"): + rasterize_meshes(meshes, image_size, 0.0, 2, bin_size) + def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size): # Square based pyramid mesh. # fmt: off diff --git a/tests/test_rasterize_points.py b/tests/test_rasterize_points.py index e46dc56b..a7591c52 100644 --- a/tests/test_rasterize_points.py +++ b/tests/test_rasterize_points.py @@ -212,6 +212,13 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase): if compare_grads: self.assertClose(grad_points1, grad_points2, atol=2e-6) + def test_bin_size_error(self): + points = Pointclouds(points=torch.rand(5, 100, 3)) + image_size = 1024 + bin_size = 16 + with self.assertRaisesRegex(ValueError, "bin_size too small"): + rasterize_points(points, image_size, 0.0, 2, bin_size=bin_size) + def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None): # Test case where all points are behind the camera -- nothing should # get rasterized