mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
fix ndc/screen problem in blender/llff (#39)
Summary: X-link: https://github.com/fairinternal/pytorch3d/pull/39 Blender and LLFF cameras were sending screen space focal length and principal point to a camera init function expecting NDC Reviewed By: shapovalov Differential Revision: D37788686 fbshipit-source-id: 2ddf7436248bc0d174eceb04c288b93858138582
This commit is contained in:
parent
67840f8320
commit
38fd8380f7
@ -41,11 +41,10 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
path_manager=path_manager,
|
||||
)
|
||||
H, W, focal = hwf
|
||||
H, W = int(H), int(W)
|
||||
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
||||
self.poses = _interpret_blender_cameras(poses, focal)
|
||||
# pyre-ignore[16]
|
||||
self.images = images_masks[:, :3]
|
||||
# pyre-ignore[16]
|
||||
|
@ -49,12 +49,12 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
)
|
||||
i_split = (i_train, i_test, i_test)
|
||||
H, W, focal = hwf
|
||||
H, W = int(H), int(W)
|
||||
focal_ndc = 2 * focal / min(H, W)
|
||||
images = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
poses = torch.from_numpy(poses)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
||||
self.poses = _interpret_blender_cameras(poses, focal_ndc)
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
# pyre-ignore[16]
|
||||
|
@ -46,7 +46,12 @@ def _local_path(path_manager, path):
|
||||
|
||||
|
||||
def load_blender_data(
|
||||
basedir, half_res=False, testskip=1, debug=False, path_manager=None
|
||||
basedir,
|
||||
half_res=False,
|
||||
testskip=1,
|
||||
debug=False,
|
||||
path_manager=None,
|
||||
focal_length_in_screen_space=False,
|
||||
):
|
||||
splits = ["train", "val", "test"]
|
||||
metas = {}
|
||||
@ -84,7 +89,10 @@ def load_blender_data(
|
||||
|
||||
H, W = imgs[0].shape[:2]
|
||||
camera_angle_x = float(meta["camera_angle_x"])
|
||||
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
|
||||
if focal_length_in_screen_space:
|
||||
focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
|
||||
else:
|
||||
focal = 1 / np.tan(0.5 * camera_angle_x)
|
||||
|
||||
render_poses = torch.stack(
|
||||
[
|
||||
@ -100,7 +108,8 @@ def load_blender_data(
|
||||
|
||||
H = H // 32
|
||||
W = W // 32
|
||||
focal = focal / 32.0
|
||||
if focal_length_in_screen_space:
|
||||
focal = focal / 32.0
|
||||
imgs = [
|
||||
torch.from_numpy(
|
||||
cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)
|
||||
@ -117,7 +126,8 @@ def load_blender_data(
|
||||
# TODO: resize images using INTER_AREA (cv2)
|
||||
H = H // 2
|
||||
W = W // 2
|
||||
focal = focal / 2.0
|
||||
if focal_length_in_screen_space:
|
||||
focal = focal / 2.0
|
||||
imgs = [
|
||||
torch.from_numpy(
|
||||
cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)
|
||||
|
@ -169,7 +169,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
|
||||
|
||||
def _interpret_blender_cameras(
|
||||
poses: torch.Tensor, H: int, W: int, focal: float
|
||||
poses: torch.Tensor, focal: float
|
||||
) -> List[PerspectiveCameras]:
|
||||
"""
|
||||
Convert 4x4 matrices representing cameras in blender format
|
||||
@ -177,6 +177,7 @@ def _interpret_blender_cameras(
|
||||
|
||||
Args:
|
||||
poses: N x 3 x 4 camera matrices
|
||||
focal: ndc space focal length
|
||||
"""
|
||||
pose_target_cameras = []
|
||||
for pose_target in poses:
|
||||
@ -191,8 +192,8 @@ def _interpret_blender_cameras(
|
||||
|
||||
Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
|
||||
|
||||
focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
|
||||
principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
|
||||
focal_length_pt3 = torch.FloatTensor([[focal, focal]])
|
||||
principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])
|
||||
|
||||
cameras = PerspectiveCameras(
|
||||
focal_length=focal_length_pt3,
|
||||
|
@ -7,6 +7,7 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
|
||||
BlenderDatasetMapProvider,
|
||||
)
|
||||
@ -37,6 +38,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
object_name="lego",
|
||||
)
|
||||
dataset_map = provider.get_dataset_map()
|
||||
known_matrix = torch.zeros(1, 4, 4)
|
||||
known_matrix[0, 0, 0] = 2.7778
|
||||
known_matrix[0, 1, 1] = 2.7778
|
||||
known_matrix[0, 2, 3] = 1
|
||||
known_matrix[0, 3, 2] = 1
|
||||
|
||||
for name, length in [("train", 100), ("val", 100), ("test", 200)]:
|
||||
dataset = getattr(dataset_map, name)
|
||||
@ -48,6 +54,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
# corner of image is background
|
||||
self.assertEqual(value.fg_probability[0, 0, 0], 0)
|
||||
self.assertEqual(value.fg_probability.max(), 1.0)
|
||||
self.assertIsInstance(value.camera, PerspectiveCameras)
|
||||
self.assertEqual(len(value.camera), 1)
|
||||
self.assertIsNone(value.camera.K)
|
||||
matrix = value.camera.get_projection_transform().get_matrix()
|
||||
self.assertClose(matrix, known_matrix, atol=1e-4)
|
||||
self.assertIsInstance(value, FrameData)
|
||||
|
||||
def test_llff(self):
|
||||
@ -60,6 +71,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
object_name="fern",
|
||||
)
|
||||
dataset_map = provider.get_dataset_map()
|
||||
known_matrix = torch.zeros(1, 4, 4)
|
||||
known_matrix[0, 0, 0] = 2.1564
|
||||
known_matrix[0, 1, 1] = 2.1564
|
||||
known_matrix[0, 2, 3] = 1
|
||||
known_matrix[0, 3, 2] = 1
|
||||
|
||||
for name, length, frame_type in [
|
||||
("train", 17, "known"),
|
||||
@ -73,6 +89,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
self.assertIsInstance(value, FrameData)
|
||||
self.assertEqual(value.frame_type, frame_type)
|
||||
self.assertEqual(value.image_rgb.shape, (3, 378, 504))
|
||||
self.assertIsInstance(value.camera, PerspectiveCameras)
|
||||
self.assertEqual(len(value.camera), 1)
|
||||
self.assertIsNone(value.camera.K)
|
||||
matrix = value.camera.get_projection_transform().get_matrix()
|
||||
self.assertClose(matrix, known_matrix, atol=1e-4)
|
||||
|
||||
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
|
||||
for batch in dataset_map.test.get_eval_batches():
|
||||
|
Loading…
x
Reference in New Issue
Block a user