fix ndc/screen problem in blender/llff (#39)

Summary:
X-link: https://github.com/fairinternal/pytorch3d/pull/39

The Blender and LLFF dataset providers were passing screen-space focal length and principal point values to a camera init function that expects NDC coordinates.
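
PyTorch3D's default PerspectiveCameras expects NDC intrinsics, where the shorter image side spans [-1, 1] and a principal point at the image centre maps to the origin. A minimal sketch of the rescaling now done by the providers (the helper name below is illustrative, not part of this change):

    def focal_screen_to_ndc(focal_px: float, H: int, W: int) -> float:
        # In NDC the shorter image side spans [-1, 1], so divide by min(H, W) / 2;
        # this is the `2 * focal / min(H, W)` rescaling used in the LLFF provider below.
        return 2.0 * focal_px / min(H, W)

A principal point at the image centre, (W / 2, H / 2) in pixels, therefore becomes (0.0, 0.0) in NDC.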

Reviewed By: shapovalov

Differential Revision: D37788686

fbshipit-source-id: 2ddf7436248bc0d174eceb04c288b93858138582
Jeremy Reizenstein 2022-07-19 10:38:13 -07:00 committed by Facebook GitHub Bot
parent 67840f8320
commit 38fd8380f7
5 changed files with 42 additions and 11 deletions

@@ -41,11 +41,10 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
             path_manager=path_manager,
         )
         H, W, focal = hwf
-        H, W = int(H), int(W)
         images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal)
         # pyre-ignore[16]
         self.images = images_masks[:, :3]
         # pyre-ignore[16]

@@ -49,12 +49,12 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         )
         i_split = (i_train, i_test, i_test)
         H, W, focal = hwf
-        H, W = int(H), int(W)
+        focal_ndc = 2 * focal / min(H, W)
         images = torch.from_numpy(images).permute(0, 3, 1, 2)
         poses = torch.from_numpy(poses)
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal_ndc)
         # pyre-ignore[16]
         self.images = images
         # pyre-ignore[16]

@@ -46,7 +46,12 @@ def _local_path(path_manager, path):
 def load_blender_data(
-    basedir, half_res=False, testskip=1, debug=False, path_manager=None
+    basedir,
+    half_res=False,
+    testskip=1,
+    debug=False,
+    path_manager=None,
+    focal_length_in_screen_space=False,
 ):
     splits = ["train", "val", "test"]
     metas = {}
@@ -84,7 +89,10 @@ def load_blender_data(
     H, W = imgs[0].shape[:2]
     camera_angle_x = float(meta["camera_angle_x"])
-    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    if focal_length_in_screen_space:
+        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    else:
+        focal = 1 / np.tan(0.5 * camera_angle_x)

     render_poses = torch.stack(
         [
@@ -100,7 +108,8 @@ def load_blender_data(
         H = H // 32
         W = W // 32
-        focal = focal / 32.0
+        if focal_length_in_screen_space:
+            focal = focal / 32.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
@@ -117,7 +126,8 @@ def load_blender_data(
         # TODO: resize images using INTER_AREA (cv2)
         H = H // 2
         W = W // 2
-        focal = focal / 2.0
+        if focal_length_in_screen_space:
+            focal = focal / 2.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)

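Note: for the square Blender renders the two branches above agree up to the NDC rescaling, since 2 * (0.5 * W / tan(camera_angle_x / 2)) / min(H, W) = 1 / tan(camera_angle_x / 2) when H == W. A quick check with example values (the camera_angle_x here is illustrative, not taken from the test data):

    import numpy as np

    H = W = 800  # square render, example size
    camera_angle_x = 0.6911  # example value; each scene's transforms JSON stores its own
    focal_screen = 0.5 * W / np.tan(0.5 * camera_angle_x)  # old behaviour, pixel units
    focal_ndc = 1 / np.tan(0.5 * camera_angle_x)  # new default, NDC units
    assert np.isclose(2 * focal_screen / min(H, W), focal_ndc)

With these example numbers focal_ndc comes out around 2.78, the same scale as the 2.7778 asserted for the lego scene in the updated test below.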
@@ -169,7 +169,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 def _interpret_blender_cameras(
-    poses: torch.Tensor, H: int, W: int, focal: float
+    poses: torch.Tensor, focal: float
 ) -> List[PerspectiveCameras]:
     """
     Convert 4x4 matrices representing cameras in blender format
@@ -177,6 +177,7 @@ def _interpret_blender_cameras(
     Args:
         poses: N x 3 x 4 camera matrices
+        focal: ndc space focal length
     """
     pose_target_cameras = []
     for pose_target in poses:
@@ -191,8 +192,8 @@ def _interpret_blender_cameras(
         Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
-        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
-        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
+        focal_length_pt3 = torch.FloatTensor([[focal, focal]])
+        principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])
         cameras = PerspectiveCameras(
             focal_length=focal_length_pt3,

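With an NDC focal length and a zero principal point, the projection matrix returned by get_projection_transform() has the focal length on its diagonal and ones in the [2, 3] and [3, 2] slots (PyTorch3D's row-vector convention), which is the pattern the updated tests below assert against. An illustrative check, reusing the lego value from the test:

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    focal = 2.7778  # NDC focal length the updated lego test expects
    cameras = PerspectiveCameras(
        focal_length=torch.tensor([[focal, focal]]),
        principal_point=torch.tensor([[0.0, 0.0]]),
    )
    matrix = cameras.get_projection_transform().get_matrix()  # shape (1, 4, 4)
    # matrix[0, 0, 0] == matrix[0, 1, 1] == focal; matrix[0, 2, 3] == matrix[0, 3, 2] == 1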
@@ -7,6 +7,7 @@
 import os
 import unittest

+import torch
 from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
     BlenderDatasetMapProvider,
 )
@@ -37,6 +38,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="lego",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.7778
+        known_matrix[0, 1, 1] = 2.7778
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
         for name, length in [("train", 100), ("val", 100), ("test", 200)]:
             dataset = getattr(dataset_map, name)
@@ -48,6 +54,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             # corner of image is background
             self.assertEqual(value.fg_probability[0, 0, 0], 0)
             self.assertEqual(value.fg_probability.max(), 1.0)
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
             self.assertIsInstance(value, FrameData)

     def test_llff(self):
@@ -60,6 +71,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="fern",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.1564
+        known_matrix[0, 1, 1] = 2.1564
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
         for name, length, frame_type in [
             ("train", 17, "known"),
@@ -73,6 +89,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             self.assertIsInstance(value, FrameData)
             self.assertEqual(value.frame_type, frame_type)
             self.assertEqual(value.image_rgb.shape, (3, 378, 504))
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
         self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
         for batch in dataset_map.test.get_eval_batches():