From 731ea53c803ce00fbfd7531b4f1c6bf263a1f6b4 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 22 Jun 2022 05:54:54 -0700 Subject: [PATCH] Llff & blender convention fix Summary: Images were coming out in the wrong format. Reviewed By: shapovalov Differential Revision: D37291278 fbshipit-source-id: c10871c37dd186982e7abf2071ac66ed583df2e6 --- pytorch3d/implicitron/dataset/blender_dataset_map_provider.py | 2 +- pytorch3d/implicitron/dataset/llff_dataset_map_provider.py | 2 +- tests/implicitron/test_data_llff.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py index c37a3a60..f9f217af 100644 --- a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py @@ -42,7 +42,7 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase): ) H, W, focal = hwf H, W = int(H), int(W) - images = torch.from_numpy(images) + images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3] # pyre-ignore[16] self.poses = _interpret_blender_cameras(poses, H, W, focal) diff --git a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py index c4e180f3..a273fa56 100644 --- a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py @@ -50,7 +50,7 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase): i_split = (i_train, i_test, i_test) H, W, focal = hwf H, W = int(H), int(W) - images = torch.from_numpy(images) + images = torch.from_numpy(images).permute(0, 3, 1, 2) poses = torch.from_numpy(poses) # pyre-ignore[16] diff --git a/tests/implicitron/test_data_llff.py b/tests/implicitron/test_data_llff.py index 7dd69245..271e4e7b 100644 --- a/tests/implicitron/test_data_llff.py +++ b/tests/implicitron/test_data_llff.py @@ -40,6 +40,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): self.assertEqual(len(dataset), length) # try getting a value value = dataset[0] + self.assertEqual(value.image_rgb.shape, (3, 800, 800)) self.assertIsInstance(value, FrameData) def test_llff(self): @@ -62,6 +63,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): value = dataset[0] self.assertIsInstance(value, FrameData) self.assertEqual(value.frame_type, frame_type) + self.assertEqual(value.image_rgb.shape, (3, 378, 504)) self.assertEqual(len(dataset_map.test.get_eval_batches()), 3) for batch in dataset_map.test.get_eval_batches():