From 38fd8380f77ba5412eca1698ad9eabe8234c01ce Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Tue, 19 Jul 2022 10:38:13 -0700
Subject: [PATCH] fix ndc/screen problem in blender/llff (#39)

Summary:
X-link: https://github.com/fairinternal/pytorch3d/pull/39

Blender and LLFF cameras were sending screen space focal length and principal point to a camera init function expecting NDC

Reviewed By: shapovalov

Differential Revision: D37788686

fbshipit-source-id: 2ddf7436248bc0d174eceb04c288b93858138582
---
 .../dataset/blender_dataset_map_provider.py   |  3 +--
 .../dataset/llff_dataset_map_provider.py      |  4 ++--
 pytorch3d/implicitron/dataset/load_blender.py | 18 ++++++++++++----
 .../dataset/single_sequence_dataset.py        |  7 ++++---
 tests/implicitron/test_data_llff.py           | 21 +++++++++++++++++++
 5 files changed, 42 insertions(+), 11 deletions(-)

diff --git a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py
index c06b4313..2eab2560 100644
--- a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py
+++ b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py
@@ -41,11 +41,10 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
             path_manager=path_manager,
         )
         H, W, focal = hwf
-        H, W = int(H), int(W)
         images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
 
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal)
         # pyre-ignore[16]
         self.images = images_masks[:, :3]
         # pyre-ignore[16]
diff --git a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
index d9ea2917..5ebd617e 100644
--- a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
+++ b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py
@@ -49,12 +49,12 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         )
         i_split = (i_train, i_test, i_test)
         H, W, focal = hwf
-        H, W = int(H), int(W)
+        focal_ndc = 2 * focal / min(H, W)
         images = torch.from_numpy(images).permute(0, 3, 1, 2)
         poses = torch.from_numpy(poses)
 
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal_ndc)
         # pyre-ignore[16]
         self.images = images
         # pyre-ignore[16]
diff --git a/pytorch3d/implicitron/dataset/load_blender.py b/pytorch3d/implicitron/dataset/load_blender.py
index f1bdeb1c..42b9cb53 100644
--- a/pytorch3d/implicitron/dataset/load_blender.py
+++ b/pytorch3d/implicitron/dataset/load_blender.py
@@ -46,7 +46,12 @@ def _local_path(path_manager, path):
 
 
 def load_blender_data(
-    basedir, half_res=False, testskip=1, debug=False, path_manager=None
+    basedir,
+    half_res=False,
+    testskip=1,
+    debug=False,
+    path_manager=None,
+    focal_length_in_screen_space=False,
 ):
     splits = ["train", "val", "test"]
     metas = {}
@@ -84,7 +89,10 @@ def load_blender_data(
 
     H, W = imgs[0].shape[:2]
     camera_angle_x = float(meta["camera_angle_x"])
-    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    if focal_length_in_screen_space:
+        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    else:
+        focal = 1 / np.tan(0.5 * camera_angle_x)
 
     render_poses = torch.stack(
         [
@@ -100,7 +108,8 @@ def load_blender_data(
     if debug:
         H = H // 32
         W = W // 32
-        focal = focal / 32.0
+        if focal_length_in_screen_space:
+            focal = focal / 32.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
@@ -117,7 +126,8 @@ def load_blender_data(
         # TODO: resize images using INTER_AREA (cv2)
         H = H // 2
         W = W // 2
-        focal = focal / 2.0
+        if focal_length_in_screen_space:
+            focal = focal / 2.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
diff --git a/pytorch3d/implicitron/dataset/single_sequence_dataset.py b/pytorch3d/implicitron/dataset/single_sequence_dataset.py
index 0b28e6a7..7cdccd23 100644
--- a/pytorch3d/implicitron/dataset/single_sequence_dataset.py
+++ b/pytorch3d/implicitron/dataset/single_sequence_dataset.py
@@ -169,7 +169,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 
 
 def _interpret_blender_cameras(
-    poses: torch.Tensor, H: int, W: int, focal: float
+    poses: torch.Tensor, focal: float
 ) -> List[PerspectiveCameras]:
     """
     Convert 4x4 matrices representing cameras in blender format
@@ -177,6 +177,7 @@ def _interpret_blender_cameras(
 
     Args:
         poses: N x 3 x 4 camera matrices
+        focal: ndc space focal length
     """
     pose_target_cameras = []
     for pose_target in poses:
@@ -191,8 +192,8 @@ def _interpret_blender_cameras(
 
         Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
 
-        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
-        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
+        focal_length_pt3 = torch.FloatTensor([[focal, focal]])
+        principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])
 
         cameras = PerspectiveCameras(
             focal_length=focal_length_pt3,
diff --git a/tests/implicitron/test_data_llff.py b/tests/implicitron/test_data_llff.py
index 82c1eb3d..9f73478b 100644
--- a/tests/implicitron/test_data_llff.py
+++ b/tests/implicitron/test_data_llff.py
@@ -7,6 +7,7 @@
 import os
 import unittest
 
+import torch
 from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
     BlenderDatasetMapProvider,
 )
@@ -37,6 +38,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="lego",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.7778
+        known_matrix[0, 1, 1] = 2.7778
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
 
         for name, length in [("train", 100), ("val", 100), ("test", 200)]:
             dataset = getattr(dataset_map, name)
@@ -48,6 +54,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             # corner of image is background
             self.assertEqual(value.fg_probability[0, 0, 0], 0)
             self.assertEqual(value.fg_probability.max(), 1.0)
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
             self.assertIsInstance(value, FrameData)
 
     def test_llff(self):
@@ -60,6 +71,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="fern",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.1564
+        known_matrix[0, 1, 1] = 2.1564
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
 
         for name, length, frame_type in [
             ("train", 17, "known"),
@@ -73,6 +89,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             self.assertIsInstance(value, FrameData)
             self.assertEqual(value.frame_type, frame_type)
             self.assertEqual(value.image_rgb.shape, (3, 378, 504))
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
 
         self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
         for batch in dataset_map.test.get_eval_batches():
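
For reference, the screen-space and NDC parameterizations describe the same camera: a focal length of 0.5 * W / tan(0.5 * camera_angle_x) pixels with the principal point at the image centre corresponds, after the rescaling by 2 / min(H, W) that the LLFF provider now applies, to an NDC focal length with principal point (0, 0); for a square image that is exactly the 1 / tan(0.5 * camera_angle_x) the Blender loader now computes. The sketch below illustrates the conversion; it is not part of the patch, and the 800 x 800 image size and camera_angle_x value are assumptions chosen to match the 2.7778 matrix entry asserted in the new test.

    import math

    # Illustrative inputs (assumptions, not read from any dataset): a square
    # 800 x 800 render with a horizontal field of view consistent with the
    # 2.7778 NDC focal length the new test expects.
    H, W = 800, 800
    camera_angle_x = 0.6911112070083618

    # Screen-space intrinsics, as load_blender_data produced them before the
    # fix: focal length in pixels, principal point at the image centre.
    focal_screen = 0.5 * W / math.tan(0.5 * camera_angle_x)  # ~1111.1 px
    principal_screen = (W / 2, H / 2)                        # (400.0, 400.0) px

    # NDC intrinsics, which PerspectiveCameras expects by default: rescale the
    # focal length by 2 / min(H, W) (the conversion now applied in the LLFF
    # provider); a centred principal point maps to (0.0, 0.0).
    focal_ndc = 2 * focal_screen / min(H, W)
    principal_ndc = (0.0, 0.0)

    # For a square image this is exactly the expression the Blender loader
    # now uses directly.
    assert abs(focal_ndc - 1 / math.tan(0.5 * camera_angle_x)) < 1e-9
    print(round(focal_ndc, 4))  # 2.7778

The 2.1564 value checked for fern arises the same way, from that scene's pixel focal length and its 378 x 504 image size.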