Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-11-04 18:02:14 +08:00)

commit 38fd8380f7 (parent 67840f8320)

fix ndc/screen problem in blender/llff (#39)

Summary: X-link: https://github.com/fairinternal/pytorch3d/pull/39

Blender and LLFF cameras were sending a screen-space focal length and principal point to a camera init function that expects NDC values.

Reviewed By: shapovalov

Differential Revision: D37788686

fbshipit-source-id: 2ddf7436248bc0d174eceb04c288b93858138582
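Background for the fix: PyTorch3D's PerspectiveCameras interpret intrinsics in NDC by default, where the shorter image side spans [-1, 1] and a centered camera has principal point (0, 0). A minimal sketch of the two conventions (not part of the commit; H, W and the focal value are illustrative):

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    H, W = 378, 504       # image size in pixels (illustrative)
    focal_screen = 407.5  # focal length in pixels (illustrative)

    # NDC convention: focal length is measured in units of half the
    # shorter image side; a centered principal point is (0, 0).
    focal_ndc = 2.0 * focal_screen / min(H, W)
    ndc_camera = PerspectiveCameras(
        focal_length=torch.tensor([[focal_ndc, focal_ndc]]),
        principal_point=torch.tensor([[0.0, 0.0]]),
    )

    # The same camera given directly in screen space; PyTorch3D converts
    # internally when in_ndc=False and image_size are supplied.
    screen_camera = PerspectiveCameras(
        focal_length=torch.tensor([[focal_screen, focal_screen]]),
        principal_point=torch.tensor([[W / 2.0, H / 2.0]]),
        in_ndc=False,
        image_size=torch.tensor([[H, W]]),
    )

The loaders below were mixing the two: screen-space numbers were handed to an NDC camera.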
@@ -41,11 +41,10 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
             path_manager=path_manager,
         )
         H, W, focal = hwf
-        H, W = int(H), int(W)
         images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
 
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal)
         # pyre-ignore[16]
         self.images = images_masks[:, :3]
         # pyre-ignore[16]

@@ -49,12 +49,12 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         )
         i_split = (i_train, i_test, i_test)
         H, W, focal = hwf
-        H, W = int(H), int(W)
+        focal_ndc = 2 * focal / min(H, W)
         images = torch.from_numpy(images).permute(0, 3, 1, 2)
         poses = torch.from_numpy(poses)
 
         # pyre-ignore[16]
-        self.poses = _interpret_blender_cameras(poses, H, W, focal)
+        self.poses = _interpret_blender_cameras(poses, focal_ndc)
         # pyre-ignore[16]
         self.images = images
         # pyre-ignore[16]

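The new focal_ndc line is the standard screen-to-NDC conversion: the NDC unit is half the shorter image side. A trivial check of the unit (a sketch with illustrative sizes): a pixel focal length equal to half the shorter side maps to exactly 1.0 in NDC.

    H, W = 378, 504  # illustrative; the factor-8 LLFF size used in the tests below
    assert 2 * (min(H, W) / 2) / min(H, W) == 1.0
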
@@ -46,7 +46,12 @@ def _local_path(path_manager, path):
 
 
 def load_blender_data(
-    basedir, half_res=False, testskip=1, debug=False, path_manager=None
+    basedir,
+    half_res=False,
+    testskip=1,
+    debug=False,
+    path_manager=None,
+    focal_length_in_screen_space=False,
 ):
     splits = ["train", "val", "test"]
     metas = {}

@@ -84,7 +89,10 @@ def load_blender_data(
 
     H, W = imgs[0].shape[:2]
     camera_angle_x = float(meta["camera_angle_x"])
-    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    if focal_length_in_screen_space:
+        focal = 0.5 * W / np.tan(0.5 * camera_angle_x)
+    else:
+        focal = 1 / np.tan(0.5 * camera_angle_x)
 
     render_poses = torch.stack(
         [

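The else branch has no W because the Blender renders are square (H == W), so the screen-to-NDC conversion cancels the width: 2 * (0.5 * W / tan(0.5 * a)) / W == 1 / tan(0.5 * a). A worked check (the camera_angle_x value ~0.6911 is an assumption, taken from the standard NeRF synthetic lego scene):

    import numpy as np

    camera_angle_x = 0.6911  # assumed, standard NeRF synthetic lego
    W = 800
    focal_screen = 0.5 * W / np.tan(0.5 * camera_angle_x)  # ~1111.1 px
    # screen -> NDC with H == W collapses to the angle-only form:
    assert np.isclose(2 * focal_screen / W, 1 / np.tan(0.5 * camera_angle_x))
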
@@ -100,7 +108,8 @@ def load_blender_data(
 
         H = H // 32
         W = W // 32
-        focal = focal / 32.0
+        if focal_length_in_screen_space:
+            focal = focal / 32.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(25, 25), interpolation=cv2.INTER_AREA)

@@ -117,7 +126,8 @@ def load_blender_data(
         # TODO: resize images using INTER_AREA (cv2)
         H = H // 2
         W = W // 2
-        focal = focal / 2.0
+        if focal_length_in_screen_space:
+            focal = focal / 2.0
         imgs = [
             torch.from_numpy(
                 cv2.resize(imgs[i], dsize=(400, 400), interpolation=cv2.INTER_AREA)

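Guarding the /32 and /2 rescaling with the flag follows from the conventions: a pixel focal length shrinks with the image, while an NDC focal length is resolution-independent. Sketch:

    focal_screen, W = 1111.1, 800
    # After halving, the pixel focal halves but the NDC value is unchanged:
    assert 2 * (focal_screen / 2) / (W // 2) == 2 * focal_screen / W
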
@@ -169,7 +169,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
 
 
 def _interpret_blender_cameras(
-    poses: torch.Tensor, H: int, W: int, focal: float
+    poses: torch.Tensor, focal: float
 ) -> List[PerspectiveCameras]:
     """
     Convert 4x4 matrices representing cameras in blender format

@@ -177,6 +177,7 @@ def _interpret_blender_cameras(
 
     Args:
         poses: N x 3 x 4 camera matrices
+        focal: ndc space focal length
     """
     pose_target_cameras = []
     for pose_target in poses:

@@ -191,8 +192,8 @@ def _interpret_blender_cameras(
 
         Rpt3, Tpt3 = mtx[:, :3].split([3, 1], dim=0)
 
-        focal_length_pt3 = torch.FloatTensor([[-focal, focal]])
-        principal_point_pt3 = torch.FloatTensor([[W / 2, H / 2]])
+        focal_length_pt3 = torch.FloatTensor([[focal, focal]])
+        principal_point_pt3 = torch.FloatTensor([[0.0, 0.0]])
 
         cameras = PerspectiveCameras(
             focal_length=focal_length_pt3,

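This hunk removes two screen-space artifacts that are wrong under NDC: the negated x focal (which mirrored the image) and the pixel principal point (W / 2, H / 2), which an NDC camera reads as an offset of hundreds of half-image-widths. A minimal sketch (not in the commit) of where a point on the optical axis lands under each version:

    import torch
    from pytorch3d.renderer import PerspectiveCameras

    point = torch.tensor([[[0.0, 0.0, 2.0]]])  # on the optical axis, 2 units ahead

    fixed = PerspectiveCameras(
        focal_length=torch.tensor([[2.7778, 2.7778]]),
        principal_point=torch.tensor([[0.0, 0.0]]),
    )
    old = PerspectiveCameras(
        focal_length=torch.tensor([[-2.7778, 2.7778]]),
        principal_point=torch.tensor([[400.0, 400.0]]),  # pixel values fed to an NDC camera
    )
    print(fixed.transform_points(point)[..., :2])  # [[[0., 0.]]]   image center
    print(old.transform_points(point)[..., :2])    # [[[400., 400.]]]  far off screen
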
@@ -7,6 +7,7 @@
 import os
 import unittest
 
+import torch
 from pytorch3d.implicitron.dataset.blender_dataset_map_provider import (
     BlenderDatasetMapProvider,
 )

@@ -37,6 +38,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="lego",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.7778
+        known_matrix[0, 1, 1] = 2.7778
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
 
         for name, length in [("train", 100), ("val", 100), ("test", 200)]:
             dataset = getattr(dataset_map, name)

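Where 2.7778 comes from: it is the NDC focal 1 / tan(0.5 * camera_angle_x) computed by the loader above (assuming the lego scene's camera_angle_x of ~0.6911 from the standard NeRF synthetic data). The [0, 2, 3] and [0, 3, 2] entries are the constant homogeneous part of PyTorch3D's row-vector projection matrix.

    import numpy as np
    print(1 / np.tan(0.5 * 0.6911))  # ~2.7778, the NDC focal asserted above
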
@@ -48,6 +54,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             # corner of image is background
             self.assertEqual(value.fg_probability[0, 0, 0], 0)
             self.assertEqual(value.fg_probability.max(), 1.0)
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
             self.assertIsInstance(value, FrameData)
 
     def test_llff(self):

@@ -60,6 +71,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             object_name="fern",
         )
         dataset_map = provider.get_dataset_map()
+        known_matrix = torch.zeros(1, 4, 4)
+        known_matrix[0, 0, 0] = 2.1564
+        known_matrix[0, 1, 1] = 2.1564
+        known_matrix[0, 2, 3] = 1
+        known_matrix[0, 3, 2] = 1
 
         for name, length, frame_type in [
             ("train", 17, "known"),

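Similarly, 2.1564 is the LLFF provider's 2 * focal / min(H, W) with the factor-8 fern images at 378 x 504; the pixel focal of ~407.56 is an assumption reverse-consistent with the constant, not stated in the diff.

    print(2 * 407.56 / min(378, 504))  # ~2.1564
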
@@ -73,6 +89,11 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
             self.assertIsInstance(value, FrameData)
             self.assertEqual(value.frame_type, frame_type)
             self.assertEqual(value.image_rgb.shape, (3, 378, 504))
+            self.assertIsInstance(value.camera, PerspectiveCameras)
+            self.assertEqual(len(value.camera), 1)
+            self.assertIsNone(value.camera.K)
+            matrix = value.camera.get_projection_transform().get_matrix()
+            self.assertClose(matrix, known_matrix, atol=1e-4)
 
         self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
         for batch in dataset_map.test.get_eval_batches():