provide fg_probability for blender data

Summary: The Blender synthetic dataset stores object masks in the alpha channel of its RGBA images. Provide these as fg_probability in the corresponding dataset (a brief illustrative sketch follows the commit metadata below).

Reviewed By: shapovalov

Differential Revision: D37344380

fbshipit-source-id: 3ddacad9d667c0fa0ae5a61fb1d2ffc806c9abf3
Jeremy Reizenstein 2022-06-22 06:11:50 -07:00 committed by Facebook GitHub Bot
parent 731ea53c80
commit 3e4fb0b9d9
4 changed files with 28 additions and 3 deletions
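Not part of the commit itself, but for orientation: the core of the change is to keep the alpha channel when the blender RGBA images are converted to tensors and expose it as a foreground-probability map. A minimal, self-contained sketch of that split, using random stand-in data in place of the real loader output (the names rgba, images_masks, etc. are illustrative only):

    # Minimal sketch: split an [N, H, W, 4] RGBA stack into RGB + foreground mask.
    # `rgba` is random stand-in data, not the real blender loader output.
    import torch

    rgba = torch.rand(2, 8, 8, 4)
    images_masks = rgba.permute(0, 3, 1, 2)   # -> [N, 4, H, W]
    images = images_masks[:, :3]              # RGB channels
    fg_probabilities = images_masks[:, 3:4]   # alpha channel as foreground probability
    assert images.shape == (2, 3, 8, 8)
    assert fg_probabilities.shape == (2, 1, 8, 8)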


@@ -42,11 +42,13 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         )
         H, W, focal = hwf
         H, W = int(H), int(W)
-        images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3]
+        images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
 
         # pyre-ignore[16]
         self.poses = _interpret_blender_cameras(poses, H, W, focal)
         # pyre-ignore[16]
-        self.images = images
+        self.images = images_masks[:, :3]
+        # pyre-ignore[16]
+        self.fg_probabilities = images_masks[:, 3:4]
         # pyre-ignore[16]
         self.i_split = i_split


@@ -58,4 +58,6 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
         # pyre-ignore[16]
         self.images = images
+        # pyre-ignore[16]
+        self.fg_probabilities = None
         # pyre-ignore[16]
         self.i_split = i_split


@@ -38,6 +38,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
     """
 
     images: List[torch.Tensor] = field()
+    fg_probabilities: Optional[List[torch.Tensor]] = field()
    poses: List[PerspectiveCameras] = field()
     object_name: str = field()
     frame_types: List[str] = field()
@@ -55,6 +56,9 @@ class SingleSceneDataset(DatasetBase, Configurable):
         image = self.images[index]
         pose = self.poses[index]
         frame_type = self.frame_types[index]
+        fg_probability = (
+            None if self.fg_probabilities is None else self.fg_probabilities[index]
+        )
         frame_data = FrameData(
             frame_number=index,
@@ -63,6 +67,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
             camera=pose,
             image_size_hw=torch.tensor(image.shape[1:]),
             image_rgb=image,
+            fg_probability=fg_probability,
             frame_type=frame_type,
         )
         return frame_data
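As a hedged usage sketch (not in the commit): once __getitem__ fills fg_probability into FrameData, a consumer can use it to mask out the background, handling the None case that mask-free providers such as LLFF produce. The tensors below are random stand-ins for a real frame's fields:

    # Hypothetical consumer of the new field; `image_rgb` / `fg_probability`
    # stand in for FrameData.image_rgb / FrameData.fg_probability.
    import torch

    image_rgb = torch.rand(3, 800, 800)
    fg_probability = torch.rand(1, 800, 800)  # None for providers without masks (e.g. LLFF)
    if fg_probability is not None:
        # [1, H, W] broadcasts over [3, H, W]: background pixels go to zero.
        image_rgb = image_rgb * fg_probability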
@@ -100,7 +105,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
     def _load_data(self) -> None:
         # This must be defined by each subclass,
-        # and should set poses, images and i_split on self.
+        # and should set the following on self.
+        # - poses: a list of length-1 camera objects
+        # - images: [N, 3, H, W] tensor of rgb images - floats in [0, 1]
+        # - fg_probabilities: None or [N, 1, H, W] tensor of floats in [0, 1]
+        # - i_split: List[List[int]] of indices for the train/val/test subsets.
         raise NotImplementedError
 
     def _get_dataset(
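For illustration only, a hypothetical subclass satisfying the _load_data contract documented above might set the attributes like this (the class name, image size, and split indices are made up; only the attribute names and shapes follow the comment):

    # Hypothetical minimal provider following the contract above.
    import torch
    from pytorch3d.renderer import PerspectiveCameras

    class _ToyProvider:  # stand-in for a SingleSceneDatasetMapProviderBase subclass
        def _load_data(self) -> None:
            n = 4
            self.poses = [PerspectiveCameras() for _ in range(n)]  # length-1 cameras
            self.images = torch.rand(n, 3, 16, 16)                 # rgb floats in [0, 1]
            self.fg_probabilities = torch.rand(n, 1, 16, 16)       # or None if no masks
            self.i_split = [[0, 1], [2], [3]]                      # train/val/test indices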
@@ -110,6 +119,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
         # pyre-ignore[16]
         split = self.i_split[split_idx]
         frame_types = [frame_type] * len(split)
+        fg_probabilities = (
+            None
+            # pyre-ignore[16]
+            if self.fg_probabilities is None
+            else self.fg_probabilities[split]
+        )
         eval_batches = [[i] for i in range(len(split))]
         if split_idx != 0 and self.n_known_frames_for_test is not None:
             train_split = self.i_split[0]
@@ -130,6 +145,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
             object_name=self.object_name,
             # pyre-ignore[16]
             images=self.images[split],
+            fg_probabilities=fg_probabilities,
             # pyre-ignore[16]
             poses=[self.poses[i] for i in split],
             frame_types=frame_types,


@@ -41,6 +41,10 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
         # try getting a value
         value = dataset[0]
         self.assertEqual(value.image_rgb.shape, (3, 800, 800))
+        self.assertEqual(value.fg_probability.shape, (1, 800, 800))
+        # corner of image is background
+        self.assertEqual(value.fg_probability[0, 0, 0], 0)
+        self.assertEqual(value.fg_probability.max(), 1.0)
         self.assertIsInstance(value, FrameData)
 
     def test_llff(self):
@@ -90,6 +94,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
         for i, frame_type in enumerate(types):
             value = dataset[i]
             self.assertEqual(value.frame_type, frame_type)
+            self.assertIsNone(value.fg_probability)
 
         self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
         for batch in dataset_map.test.get_eval_batches():