mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
a formula for bin size for images over 64x64 (#90)
Summary: Signed-off-by: Michele Sanna <sanna@arrival.com> fixes the bin_size calculation with a formula for any image_size > 64. Matches the values chosen so far. simple test: ``` import numpy as np import matplotlib.pyplot as plt image_size = np.arange(64, 2048) bin_size = np.where(image_size <= 64, 8, (2 ** np.maximum(np.ceil(np.log2(image_size)) - 4, 4)).astype(int)) print(image_size) print(bin_size) for ims, bins in zip(image_size, bin_size): if ims <= 64: assert bins == 8 elif ims <= 256: assert bins == 16 elif ims <= 512: assert bins == 32 elif ims <= 1024: assert bins == 64 elif ims <= 2048: assert bins == 128 assert (ims + bins - 1) // bins < 22 plt.plot(image_size, bin_size) plt.grid() plt.show() ```  Pull Request resolved: https://github.com/facebookresearch/pytorch3d/pull/90 Reviewed By: jcjohnson Differential Revision: D21160372 Pulled By: nikhilaravi fbshipit-source-id: 660cf5832f4ca5be243c435a6bed969596fc0188
This commit is contained in:
parent
c3d636dc8c
commit
f8acecb6b3
@ -696,7 +696,7 @@ at::Tensor RasterizeMeshesCoarseCuda(
|
|||||||
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
|
const int num_bins = 1 + (image_size - 1) / bin_size; // Divide round up.
|
||||||
const int M = max_faces_per_bin;
|
const int M = max_faces_per_bin;
|
||||||
|
|
||||||
if (num_bins >= 22) {
|
if (num_bins >= kMaxFacesPerBin) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "Got " << num_bins << "; that's too many!";
|
ss << "Got " << num_bins << "; that's too many!";
|
||||||
AT_ERROR(ss.str());
|
AT_ERROR(ss.str());
|
||||||
|
@ -17,6 +17,8 @@ __device__ inline float PixToNdc(int i, int S) {
|
|||||||
// TODO: is 8 enough? Would increasing have performance considerations?
|
// TODO: is 8 enough? Would increasing have performance considerations?
|
||||||
const int32_t kMaxPointsPerPixel = 150;
|
const int32_t kMaxPointsPerPixel = 150;
|
||||||
|
|
||||||
|
const int32_t kMaxFacesPerBin = 22;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__device__ inline void BubbleSort(T* arr, int n) {
|
__device__ inline void BubbleSort(T* arr, int n) {
|
||||||
// Bubble sort. We only use it for tiny thread-local arrays (n < 8); in this
|
// Bubble sort. We only use it for tiny thread-local arrays (n < 8); in this
|
||||||
|
@ -11,6 +11,10 @@ from pytorch3d import _C
|
|||||||
# TODO make the epsilon user configurable
|
# TODO make the epsilon user configurable
|
||||||
kEpsilon = 1e-8
|
kEpsilon = 1e-8
|
||||||
|
|
||||||
|
# Maximum number of faces per bin for
|
||||||
|
# coarse-to-fine rasterization
|
||||||
|
kMaxFacesPerBin = 22
|
||||||
|
|
||||||
|
|
||||||
def rasterize_meshes(
|
def rasterize_meshes(
|
||||||
meshes,
|
meshes,
|
||||||
@ -107,12 +111,23 @@ def rasterize_meshes(
|
|||||||
# TODO better heuristics for bin size.
|
# TODO better heuristics for bin size.
|
||||||
if image_size <= 64:
|
if image_size <= 64:
|
||||||
bin_size = 8
|
bin_size = 8
|
||||||
elif image_size <= 256:
|
else:
|
||||||
bin_size = 16
|
# Heuristic based formula maps image_size -> bin_size as follows:
|
||||||
elif image_size <= 512:
|
# image_size <= 64 -> 8
|
||||||
bin_size = 32
|
# 64 < image_size <= 256 -> 16
|
||||||
elif image_size <= 1024:
|
# 256 < image_size <= 512 -> 32
|
||||||
bin_size = 64
|
# 512 < image_size <= 1024 -> 64
|
||||||
|
# 1024 < image_size <= 2048 -> 128
|
||||||
|
bin_size = int(2 ** max(np.ceil(np.log2(image_size)) - 4, 4))
|
||||||
|
|
||||||
|
if bin_size != 0:
|
||||||
|
# There is a limit on the number of faces per bin in the cuda kernel.
|
||||||
|
faces_per_bin = 1 + (image_size - 1) // bin_size
|
||||||
|
if faces_per_bin >= kMaxFacesPerBin:
|
||||||
|
raise ValueError(
|
||||||
|
"bin_size too small, number of faces per bin must be less than %d; got %d"
|
||||||
|
% (kMaxFacesPerBin, faces_per_bin)
|
||||||
|
)
|
||||||
|
|
||||||
if max_faces_per_bin is None:
|
if max_faces_per_bin is None:
|
||||||
max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))
|
max_faces_per_bin = int(max(10000, verts_packed.shape[0] / 5))
|
||||||
|
@ -7,6 +7,11 @@ from pytorch3d import _C
|
|||||||
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
from pytorch3d.renderer.mesh.rasterize_meshes import pix_to_ndc
|
||||||
|
|
||||||
|
|
||||||
|
# Maximum number of points per bin for
|
||||||
|
# coarse-to-fine rasterization
|
||||||
|
kMaxPointsPerBin = 22
|
||||||
|
|
||||||
|
|
||||||
# TODO(jcjohns): Support non-square images
|
# TODO(jcjohns): Support non-square images
|
||||||
def rasterize_points(
|
def rasterize_points(
|
||||||
pointclouds,
|
pointclouds,
|
||||||
@ -82,6 +87,15 @@ def rasterize_points(
|
|||||||
elif image_size <= 1024:
|
elif image_size <= 1024:
|
||||||
bin_size = 64
|
bin_size = 64
|
||||||
|
|
||||||
|
if bin_size != 0:
|
||||||
|
# There is a limit on the number of points per bin in the cuda kernel.
|
||||||
|
points_per_bin = 1 + (image_size - 1) // bin_size
|
||||||
|
if points_per_bin >= kMaxPointsPerBin:
|
||||||
|
raise ValueError(
|
||||||
|
"bin_size too small, number of points per bin must be less than %d; got %d"
|
||||||
|
% (kMaxPointsPerBin, points_per_bin)
|
||||||
|
)
|
||||||
|
|
||||||
if max_points_per_bin is None:
|
if max_points_per_bin is None:
|
||||||
max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))
|
max_points_per_bin = int(max(10000, points_packed.shape[0] / 5))
|
||||||
|
|
||||||
|
@ -382,6 +382,13 @@ class TestRasterizeMeshes(TestCaseMixin, unittest.TestCase):
|
|||||||
args = ()
|
args = ()
|
||||||
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
|
self._compare_impls(fn1, fn2, args, args, verts1, verts2, compare_grads=True)
|
||||||
|
|
||||||
|
def test_bin_size_error(self):
|
||||||
|
meshes = ico_sphere(2)
|
||||||
|
image_size = 1024
|
||||||
|
bin_size = 16
|
||||||
|
with self.assertRaisesRegex(ValueError, "bin_size too small"):
|
||||||
|
rasterize_meshes(meshes, image_size, 0.0, 2, bin_size)
|
||||||
|
|
||||||
def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size):
|
def _test_back_face_culling(self, rasterize_meshes_fn, device, bin_size):
|
||||||
# Square based pyramid mesh.
|
# Square based pyramid mesh.
|
||||||
# fmt: off
|
# fmt: off
|
||||||
|
@ -212,6 +212,13 @@ class TestRasterizePoints(TestCaseMixin, unittest.TestCase):
|
|||||||
if compare_grads:
|
if compare_grads:
|
||||||
self.assertClose(grad_points1, grad_points2, atol=2e-6)
|
self.assertClose(grad_points1, grad_points2, atol=2e-6)
|
||||||
|
|
||||||
|
def test_bin_size_error(self):
|
||||||
|
points = Pointclouds(points=torch.rand(5, 100, 3))
|
||||||
|
image_size = 1024
|
||||||
|
bin_size = 16
|
||||||
|
with self.assertRaisesRegex(ValueError, "bin_size too small"):
|
||||||
|
rasterize_points(points, image_size, 0.0, 2, bin_size=bin_size)
|
||||||
|
|
||||||
def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
|
def _test_behind_camera(self, rasterize_points_fn, device, bin_size=None):
|
||||||
# Test case where all points are behind the camera -- nothing should
|
# Test case where all points are behind the camera -- nothing should
|
||||||
# get rasterized
|
# get rasterized
|
||||||
|
Loading…
x
Reference in New Issue
Block a user