fix ndc/screen problem in blender/llff (#39)

Summary:
X-link: https://github.com/fairinternal/pytorch3d/pull/39

Blender and LLFF dataset providers were sending screen-space focal length and principal point to a camera init function that expects NDC values.
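For reference, PyTorch3D's PerspectiveCameras treats focal_length and principal_point as NDC quantities by default: the shorter image side spans [-1, 1] and the principal point is measured from the image center. A minimal sketch of the screen-to-NDC conversion this fix implies (the function name and arguments are illustrative, not part of the diff):

    def screen_to_ndc(focal_px, px, py, H, W):
        # PyTorch3D NDC convention: shorter image side spans [-1, 1],
        # +X left, +Y up, so a principal point at the pixel center
        # (W / 2, H / 2) maps to (0, 0).
        s = min(H, W)
        focal_ndc = 2.0 * focal_px / s
        px_ndc = -(px - W / 2.0) * 2.0 / s
        py_ndc = -(py - H / 2.0) * 2.0 / s
        return focal_ndc, px_ndc, py_ndc

Feeding pixel-valued intrinsics into a constructor that expects these NDC values makes the focal length hundreds of times too large and pushes the principal point far outside the frustum, which is the bug fixed below.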

Reviewed By: shapovalov

Differential Revision: D37788686

fbshipit-source-id: 2ddf7436248bc0d174eceb04c288b93858138582
Author: Jeremy Reizenstein
Date: 2022-07-19 10:38:13 -07:00
Committed by: Facebook GitHub Bot
Parent: 67840f8320
Commit: 38fd8380f7

5 changed files with 42 additions and 11 deletions

pytorch3d/implicitron/dataset/blender_dataset_map_provider.py

@@ -41,11 +41,10 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
             path_manager=path_manager,
         )
         H, W, focal = hwf
-        H, W = int(H), int(W)
         images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal)
         # pyre-ignore[16]
         self.images = images_masks[:, :3]
         # pyre-ignore[16]

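With this change the Blender provider forwards the focal value from load_blender_data as-is; since focal_length_in_screen_space defaults to False (see load_blender.py below), that value is already in NDC units. A small sanity check of the equivalence for square Blender renders, using illustrative numbers rather than anything from the dataset:

    import numpy as np

    H = W = 800                      # Blender synthetic renders are square
    camera_angle_x = 0.6911          # hypothetical horizontal FoV in radians
    focal_px = 0.5 * W / np.tan(0.5 * camera_angle_x)   # old, screen space
    focal_ndc = 1.0 / np.tan(0.5 * camera_angle_x)      # new, NDC
    assert np.isclose(2 * focal_px / min(H, W), focal_ndc)

For H == W the two conventions differ only by the factor 2 / W, which is why load_blender_data can compute the NDC focal directly from the field of view.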
pytorch3d/implicitron/dataset/llff_dataset_map_provider.py

@@ -49,12 +49,12 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         )
         i_split = (i_train, i_test, i_test)
         H, W, focal = hwf
-        H, W = int(H), int(W)
+        focal_ndc = 2 * focal / min(H, W)
         images = torch.from_numpy(images).permute(0, 3, 1, 2)
         poses = torch.from_numpy(poses)
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal_ndc)
         # pyre-ignore[16]
         self.images = images
         # pyre-ignore[16]

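LLFF frames are generally not square, so here the screen-space focal from the dataset is divided by the shorter image side explicitly. An illustrative computation with made-up LLFF-like numbers:

    H, W, focal = 756.0, 1008.0, 815.0    # hypothetical hwf values
    focal_ndc = 2 * focal / min(H, W)     # ~2.16, in NDC units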
pytorch3d/implicitron/dataset/load_blender.py

@@ -46,7 +46,12 @@ def _local_path(path_manager, path):
 def load_blender_data(
-    basedir, half_res=False, testskip=1, debug=False, path_manager=None
+    basedir,
+    half_res=False,
+    testskip=1,
+    debug=False,
+    path_manager=None,
+    focal_length_in_screen_space=False,
 ):
     splits = ["train", "val", "test"]
     metas = {}
@@ -84,7 +89,10 @@ def load_blender_data(
     H, W = imgs[0].shape[:2]
     camera_angle_x = float(meta["camera_angle_x"])
-    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    if focal_length_in_screen_space:
+        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    else:
+        focal = 1 / np.tan(0.5 * camera_angle_x)
     render_poses = torch.stack(
         [
@@ -100,7 +108,8 @@ def load_blender_data(
         H = H // 32
         W = W // 32
-        focal = focal / 32.0
+        if focal_length_in_screen_space:
+            focal = focal / 32.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
@@ -117,7 +126,8 @@ def load_blender_data(
         # TODO: resize images using INTER_AREA (cv2)
         H = H // 2
         W = W // 2
-        focal = focal / 2.0
+        if focal_length_in_screen_space:
+            focal = focal / 2.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)

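The new guards reflect that only a screen-space focal length depends on resolution: halving H and W halves a pixel-valued focal but leaves the field of view, and therefore the NDC focal, unchanged. A quick check under that assumption:

    import numpy as np

    camera_angle_x = 0.6911              # illustrative FoV in radians
    for W in (800, 400, 25):             # full res, half_res, debug // 32
        focal_px = 0.5 * W / np.tan(0.5 * camera_angle_x)  # scales with W
        # The NDC focal is resolution-independent for square images:
        assert np.isclose(2 * focal_px / W, 1 / np.tan(0.5 * camera_angle_x))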
pytorch3d/implicitron/dataset/single_sequence_dataset.py

@@ -169,7 +169,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 def _interpret_blender_cameras(
-    poses: torch.Tensor, H: int, W: int, focal: float
+    poses: torch.Tensor, focal: float
 ) -> List[PerspectiveCameras]:
     """
     Convert 4x4 matrices representing cameras in blender format
@@ -177,6 +177,7 @@ def _interpret_blender_cameras(
     Args:
         poses: N x 3 x 4 camera matrices
+        focal: ndc space focal length
     """
     pose_target_cameras = []
     for pose_target in poses:
@@ -191,8 +192,8 @@ def _interpret_blender_cameras(
         Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
-        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
-        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
+        focal_length_pt3 = torch.FloatTensor([[focal, focal]])
+        principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])
         cameras = PerspectiveCameras(
             focal_length=focal_length_pt3,
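After the fix, _interpret_blender_cameras no longer needs the image size at all: intrinsics are specified in NDC, where the image center is (0, 0) regardless of resolution. A minimal sketch of building an equivalent camera directly (R and T are placeholders, and the focal value is illustrative):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    cameras = PerspectiveCameras(
        focal_length=torch.tensor([[2.0, 2.0]]),      # illustrative NDC focal
        principal_point=torch.tensor([[0.0, 0.0]]),   # image center in NDC
        R=torch.eye(3)[None],                         # placeholder rotation
        T=torch.zeros(1, 3),                          # placeholder translation
    )

With the old screen-space values, the principal point (W / 2, H / 2) landed far outside a frustum whose visible extent is roughly [-1, 1], so every render came out shifted and scaled.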