From cac6cb1b7813a4f09a05e0ade43c63292bb08b79 Mon Sep 17 00:00:00 2001 From: Ignacio Rocco Date: Fri, 5 Nov 2021 10:28:51 -0700 Subject: [PATCH] Update NDC raysampler for non-square convention (#29) Summary: - Old NDC convention had xy coords in [-1,1]x[-1,1] - New NDC convention has xy coords in [-1, 1]x[-u, u] or [-u, u]x[-1, 1] where u > 1 is the aspect ratio of the image. This PR fixes the NDC raysampler to use the new convention. Partial fix for https://github.com/facebookresearch/pytorch3d/issues/868 Pull Request resolved: https://github.com/fairinternal/pytorch3d/pull/29 Reviewed By: davnov134 Differential Revision: D31926148 Pulled By: bottler fbshipit-source-id: c6c42c60d1473b04e60ceb49c8c10951ddf03c74 --- pytorch3d/renderer/implicit/raysampling.py | 23 +++--- tests/test_raysampling.py | 82 ++++++++++++++++++++-- tests/test_render_implicit.py | 24 ++++--- tests/test_render_volumes.py | 24 ++++--- 4 files changed, 118 insertions(+), 35 deletions(-) diff --git a/pytorch3d/renderer/implicit/raysampling.py b/pytorch3d/renderer/implicit/raysampling.py index 85bf5bf8..f0c6f5e1 100644 --- a/pytorch3d/renderer/implicit/raysampling.py +++ b/pytorch3d/renderer/implicit/raysampling.py @@ -139,8 +139,8 @@ class NDCGridRaysampler(GridRaysampler): have uniformly-spaced z-coordinates between a predefined minimum and maximum depth. `NDCGridRaysampler` follows the screen conventions of the `Meshes` and `Pointclouds` - renderers. I.e. the border of the leftmost / rightmost / topmost / bottommost pixel - has coordinates 1.0 / -1.0 / 1.0 / -1.0 respectively. + renderers. I.e. the pixel coordinates are in [-1, 1]x[-u, u] or [-u, u]x[-1, 1] + where u > 1 is the aspect ratio of the image. """ def __init__( @@ -159,13 +159,20 @@ class NDCGridRaysampler(GridRaysampler): min_depth: The minimum depth of a ray-point. max_depth: The maximum depth of a ray-point. """ - half_pix_width = 1.0 / image_width - half_pix_height = 1.0 / image_height + if image_width >= image_height: + range_x = image_width / image_height + range_y = 1.0 + else: + range_x = 1.0 + range_y = image_height / image_width + + half_pix_width = range_x / image_width + half_pix_height = range_y / image_height super().__init__( - min_x=1.0 - half_pix_width, - max_x=-1.0 + half_pix_width, - min_y=1.0 - half_pix_height, - max_y=-1.0 + half_pix_height, + min_x=range_x - half_pix_width, + max_x=-range_x + half_pix_width, + min_y=range_y - half_pix_height, + max_y=-range_y + half_pix_height, image_width=image_width, image_height=image_height, n_pts_per_ray=n_pts_per_ray, diff --git a/tests/test_raysampling.py b/tests/test_raysampling.py index 9c7e3de3..9d1409ee 100644 --- a/tests/test_raysampling.py +++ b/tests/test_raysampling.py @@ -24,6 +24,69 @@ from pytorch3d.transforms import Rotate from test_cameras import init_random_cameras +class TestNDCRaysamplerConvention(TestCaseMixin, unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(42) + + def test_ndc_convention( + self, + h=428, + w=760, + ): + device = torch.device("cuda") + + camera = init_random_cameras(PerspectiveCameras, 1, random_z=True).to(device) + + depth_map = torch.ones((1, 1, h, w)).to(device) + + xyz = ray_bundle_to_ray_points( + NDCGridRaysampler( + image_width=w, + image_height=h, + n_pts_per_ray=1, + min_depth=1.0, + max_depth=1.0, + )(camera)._replace(lengths=depth_map[:, 0, ..., None]) + ).view(1, -1, 3) + + # project pointcloud + xy = camera.transform_points(xyz)[:, :, :2].squeeze() + + xy_grid = self._get_ndc_grid(h, w, device) + + self.assertClose( + xy, + xy_grid, + atol=1e-4, + ) + + def _get_ndc_grid(self, h, w, device): + if w >= h: + range_x = w / h + range_y = 1.0 + else: + range_x = 1.0 + range_y = h / w + + half_pix_width = range_x / w + half_pix_height = range_y / h + + min_x = range_x - half_pix_width + max_x = -range_x + half_pix_width + min_y = range_y - half_pix_height + max_y = -range_y + half_pix_height + + y_grid, x_grid = torch.meshgrid( + torch.linspace(min_y, max_y, h, dtype=torch.float32), + torch.linspace(min_x, max_x, w, dtype=torch.float32), + ) + + x_points = x_grid.contiguous().view(-1).to(device) + y_points = y_grid.contiguous().view(-1).to(device) + xy = torch.stack((x_points, y_points), dim=1) + return xy + + class TestRaysampling(TestCaseMixin, unittest.TestCase): def setUp(self) -> None: torch.manual_seed(42) @@ -147,12 +210,19 @@ class TestRaysampling(TestCaseMixin, unittest.TestCase): if issubclass(raysampler_type, NDCGridRaysampler): # adjust the gt bounds for NDCGridRaysampler - half_pix_width = 1.0 / image_width - half_pix_height = 1.0 / image_height - min_x_ = 1.0 - half_pix_width - max_x_ = -1.0 + half_pix_width - min_y_ = 1.0 - half_pix_height - max_y_ = -1.0 + half_pix_height + if image_width >= image_height: + range_x = image_width / image_height + range_y = 1.0 + else: + range_x = 1.0 + range_y = image_height / image_width + + half_pix_width = range_x / image_width + half_pix_height = range_y / image_height + min_x_ = range_x - half_pix_width + max_x_ = -range_x + half_pix_width + min_y_ = range_y - half_pix_height + max_y_ = -range_y + half_pix_height else: min_x_ = min_x max_x_ = max_x diff --git a/tests/test_render_implicit.py b/tests/test_render_implicit.py index 7a4ce851..22be1576 100644 --- a/tests/test_render_implicit.py +++ b/tests/test_render_implicit.py @@ -159,8 +159,12 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase): with self.assertRaises(ValueError): renderer(cameras=cameras, volumetric_function=bad_volumetric_function) - def test_compare_with_meshes_renderer( - self, batch_size=11, image_size=100, sphere_diameter=0.6 + def test_compare_with_meshes_renderer(self): + self._compare_with_meshes_renderer(image_size=(200, 100)) + self._compare_with_meshes_renderer(image_size=(100, 200)) + + def _compare_with_meshes_renderer( + self, image_size, batch_size=11, sphere_diameter=0.6 ): """ Generate a spherical RGB volumetric function and its corresponding mesh @@ -169,9 +173,7 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase): """ # generate NDC camera extrinsics and intrinsics - cameras = init_cameras( - batch_size, image_size=[image_size, image_size], ndc=True - ) + cameras = init_cameras(batch_size, image_size=image_size, ndc=True) # get rand offset of the volume sphere_centroid = torch.randn(batch_size, 3, device=cameras.device) * 0.1 @@ -179,8 +181,8 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase): # init the grid raysampler with the ndc grid raysampler = NDCGridRaysampler( - image_width=image_size, - image_height=image_size, + image_width=image_size[1], + image_height=image_size[0], n_pts_per_ray=256, min_depth=0.1, max_depth=2.0, @@ -336,9 +338,11 @@ class TestRenderImplicit(TestCaseMixin, unittest.TestCase): self.assertClose(mu_diff, torch.zeros_like(mu_diff), atol=5e-2) self.assertClose(std_diff, torch.zeros_like(std_diff), atol=6e-2) - def test_rotating_gif( - self, n_frames=50, fps=15, image_size=(100, 100), sphere_diameter=0.5 - ): + def test_rotating_gif(self): + self._rotating_gif(image_size=(200, 100)) + self._rotating_gif(image_size=(100, 200)) + + def _rotating_gif(self, image_size, n_frames=50, fps=15, sphere_diameter=0.5): """ Render a gif animation of a rotating sphere (runs only if `DEBUG==True`). """ diff --git a/tests/test_render_volumes.py b/tests/test_render_volumes.py index 88869cbc..39d1e59b 100644 --- a/tests/test_render_volumes.py +++ b/tests/test_render_volumes.py @@ -164,7 +164,7 @@ def init_cameras( p0 = torch.ones(batch_size, 2, device=device) p0[:, 0] *= image_size[1] * 0.5 p0[:, 1] *= image_size[0] * 0.5 - focal = image_size[0] * torch.ones(batch_size, device=device) + focal = max(*image_size) * torch.ones(batch_size, device=device) # convert to a Camera object cameras = PerspectiveCameras(focal, p0, R=R, T=T, device=device) @@ -295,7 +295,7 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase): _validate_ray_bundle_variables(*bad_ray_bundle) def test_compare_with_pointclouds_renderer( - self, batch_size=11, volume_size=(30, 30, 30), image_size=200 + self, batch_size=11, volume_size=(30, 30, 30), image_size=(200, 250) ): """ Generate a volume and its corresponding point cloud and check whether @@ -303,9 +303,7 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase): """ # generate NDC camera extrinsics and intrinsics - cameras = init_cameras( - batch_size, image_size=[image_size, image_size], ndc=True - ) + cameras = init_cameras(batch_size, image_size=image_size, ndc=True) # init the boundary volume for shape in ("sphere", "cube"): @@ -340,10 +338,10 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase): # init the grid raysampler with the ndc grid coord_range = 1.0 - half_pix_size = coord_range / image_size + half_pix_size = coord_range / max(*image_size) raysampler = NDCGridRaysampler( - image_width=image_size, - image_height=image_size, + image_width=image_size[1], + image_height=image_size[0], n_pts_per_ray=256, min_depth=0.1, max_depth=2.0, @@ -499,8 +497,12 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase): images_opacities_mc.permute(0, 3, 1, 2), images_opacities_mc_, atol=1e-4 ) - def test_rotating_gif( - self, n_frames=50, fps=15, volume_size=(100, 100, 100), image_size=(100, 100) + def test_rotating_gif(self): + self._rotating_gif(image_size=(200, 100)) + self._rotating_gif(image_size=(100, 200)) + + def _rotating_gif( + self, image_size, n_frames=50, fps=15, volume_size=(100, 100, 100) ): """ Render a gif animation of a rotating cube/sphere (runs only if `DEBUG==True`). @@ -586,7 +588,7 @@ class TestRenderVolumes(TestCaseMixin, unittest.TestCase): # batch_size = 4 sides of the cube batch_size = 4 - image_size = (50, 50) + image_size = (50, 40) for volume_size in ([25, 25, 25],): for sample_mode in ("bilinear", "nearest"):