From 3e4fb0b9d924951b17ce70b847b597995895cfc6 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Wed, 22 Jun 2022 06:11:50 -0700 Subject: [PATCH] provide fg_probability for blender data Summary: The blender synthetic dataset contains object masks in the alpha channel. Provide these in the corresponding dataset. Reviewed By: shapovalov Differential Revision: D37344380 fbshipit-source-id: 3ddacad9d667c0fa0ae5a61fb1d2ffc806c9abf3 --- .../dataset/blender_dataset_map_provider.py | 6 ++++-- .../dataset/llff_dataset_map_provider.py | 2 ++ .../dataset/single_sequence_dataset.py | 18 +++++++++++++++++- tests/implicitron/test_data_llff.py | 5 +++++ 4 files changed, 28 insertions(+), 3 deletions(-) diff --git a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py index f9f217af..c06b4313 100644 --- a/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/blender_dataset_map_provider.py @@ -42,11 +42,13 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase): ) H, W, focal = hwf H, W = int(H), int(W) - images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3] + images_masks = torch.from_numpy(images).permute(0, 3, 1, 2) # pyre-ignore[16] self.poses = _interpret_blender_cameras(poses, H, W, focal) # pyre-ignore[16] - self.images = images + self.images = images_masks[:, :3] + # pyre-ignore[16] + self.fg_probabilities = images_masks[:, 3:4] # pyre-ignore[16] self.i_split = i_split diff --git a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py index a273fa56..d9ea2917 100644 --- a/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/llff_dataset_map_provider.py @@ -58,4 +58,6 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase): # pyre-ignore[16] self.images = images # pyre-ignore[16] + self.fg_probabilities = None + # 
pyre-ignore[16] self.i_split = i_split diff --git a/pytorch3d/implicitron/dataset/single_sequence_dataset.py b/pytorch3d/implicitron/dataset/single_sequence_dataset.py index 6f13f307..a9757d72 100644 --- a/pytorch3d/implicitron/dataset/single_sequence_dataset.py +++ b/pytorch3d/implicitron/dataset/single_sequence_dataset.py @@ -38,6 +38,7 @@ class SingleSceneDataset(DatasetBase, Configurable): """ images: List[torch.Tensor] = field() + fg_probabilities: Optional[List[torch.Tensor]] = field() poses: List[PerspectiveCameras] = field() object_name: str = field() frame_types: List[str] = field() @@ -55,6 +56,9 @@ class SingleSceneDataset(DatasetBase, Configurable): image = self.images[index] pose = self.poses[index] frame_type = self.frame_types[index] + fg_probability = ( + None if self.fg_probabilities is None else self.fg_probabilities[index] + ) frame_data = FrameData( frame_number=index, @@ -63,6 +67,7 @@ class SingleSceneDataset(DatasetBase, Configurable): camera=pose, image_size_hw=torch.tensor(image.shape[1:]), image_rgb=image, + fg_probability=fg_probability, frame_type=frame_type, ) return frame_data @@ -100,7 +105,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase): def _load_data(self) -> None: # This must be defined by each subclass, - # and should set poses, images and i_split on self. + # and should set the following on self. + # - poses: a list of length-1 camera objects + # - images: [N, 3, H, W] tensor of rgb images - floats in [0,1] + # - fg_probabilities: None or [N, 1, H, W] of floats in [0,1] + # - i_split: List[List[int]] of indices for train/val/test subsets. 
raise NotImplementedError def _get_dataset( @@ -110,6 +119,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase): # pyre-ignore[16] split = self.i_split[split_idx] frame_types = [frame_type] * len(split) + fg_probabilities = ( + None + # pyre-ignore[16] + if self.fg_probabilities is None + else self.fg_probabilities[split] + ) eval_batches = [[i] for i in range(len(split))] if split_idx != 0 and self.n_known_frames_for_test is not None: train_split = self.i_split[0] @@ -130,6 +145,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase): object_name=self.object_name, # pyre-ignore[16] images=self.images[split], + fg_probabilities=fg_probabilities, # pyre-ignore[16] poses=[self.poses[i] for i in split], frame_types=frame_types, diff --git a/tests/implicitron/test_data_llff.py b/tests/implicitron/test_data_llff.py index 271e4e7b..e77a1f1d 100644 --- a/tests/implicitron/test_data_llff.py +++ b/tests/implicitron/test_data_llff.py @@ -41,6 +41,10 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): # try getting a value value = dataset[0] self.assertEqual(value.image_rgb.shape, (3, 800, 800)) + self.assertEqual(value.fg_probability.shape, (1, 800, 800)) + # corner of image is background + self.assertEqual(value.fg_probability[0, 0, 0], 0) + self.assertEqual(value.fg_probability.max(), 1.0) self.assertIsInstance(value, FrameData) def test_llff(self): @@ -90,6 +94,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase): for i, frame_type in enumerate(types): value = dataset[i] self.assertEqual(value.frame_type, frame_type) + self.assertIsNone(value.fg_probability) self.assertEqual(len(dataset_map.test.get_eval_batches()), 3) for batch in dataset_map.test.get_eval_batches():