mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
provide fg_probability for blender data
Summary: The blender synthetic dataset contains object masks in the alpha channel. Provide these in the corresponding dataset. Reviewed By: shapovalov Differential Revision: D37344380 fbshipit-source-id: 3ddacad9d667c0fa0ae5a61fb1d2ffc806c9abf3
This commit is contained in:
parent
731ea53c80
commit
3e4fb0b9d9
@ -42,11 +42,13 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
)
|
||||
H, W, focal = hwf
|
||||
H, W = int(H), int(W)
|
||||
images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3]
|
||||
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
self.images = images_masks[:, :3]
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = images_masks[:, 3:4]
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
|
@ -58,4 +58,6 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = None
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
|
@ -38,6 +38,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
"""
|
||||
|
||||
images: List[torch.Tensor] = field()
|
||||
fg_probabilities: Optional[List[torch.Tensor]] = field()
|
||||
poses: List[PerspectiveCameras] = field()
|
||||
object_name: str = field()
|
||||
frame_types: List[str] = field()
|
||||
@ -55,6 +56,9 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
image = self.images[index]
|
||||
pose = self.poses[index]
|
||||
frame_type = self.frame_types[index]
|
||||
fg_probability = (
|
||||
None if self.fg_probabilities is None else self.fg_probabilities[index]
|
||||
)
|
||||
|
||||
frame_data = FrameData(
|
||||
frame_number=index,
|
||||
@ -63,6 +67,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
camera=pose,
|
||||
image_size_hw=torch.tensor(image.shape[1:]),
|
||||
image_rgb=image,
|
||||
fg_probability=fg_probability,
|
||||
frame_type=frame_type,
|
||||
)
|
||||
return frame_data
|
||||
@ -100,7 +105,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
|
||||
def _load_data(self) -> None:
|
||||
# This must be defined by each subclass,
|
||||
# and should set poses, images and i_split on self.
|
||||
# and should set the following on self.
|
||||
# - poses: a list of length-1 camera objects
|
||||
# - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
|
||||
# - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
|
||||
# - splits: List[List[int]] of indices for train/val/test subsets.
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dataset(
|
||||
@ -110,6 +119,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
# pyre-ignore[16]
|
||||
split = self.i_split[split_idx]
|
||||
frame_types = [frame_type] * len(split)
|
||||
fg_probabilities = (
|
||||
None
|
||||
# pyre-ignore[16]
|
||||
if self.fg_probabilities is None
|
||||
else self.fg_probabilities[split]
|
||||
)
|
||||
eval_batches = [[i] for i in range(len(split))]
|
||||
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
||||
train_split = self.i_split[0]
|
||||
@ -130,6 +145,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
object_name=self.object_name,
|
||||
# pyre-ignore[16]
|
||||
images=self.images[split],
|
||||
fg_probabilities=fg_probabilities,
|
||||
# pyre-ignore[16]
|
||||
poses=[self.poses[i] for i in split],
|
||||
frame_types=frame_types,
|
||||
|
@ -41,6 +41,10 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
# try getting a value
|
||||
value = dataset[0]
|
||||
self.assertEqual(value.image_rgb.shape, (3, 800, 800))
|
||||
self.assertEqual(value.fg_probability.shape, (1, 800, 800))
|
||||
# corner of image is background
|
||||
self.assertEqual(value.fg_probability[0, 0, 0], 0)
|
||||
self.assertEqual(value.fg_probability.max(), 1.0)
|
||||
self.assertIsInstance(value, FrameData)
|
||||
|
||||
def test_llff(self):
|
||||
@ -90,6 +94,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
||||
for i, frame_type in enumerate(types):
|
||||
value = dataset[i]
|
||||
self.assertEqual(value.frame_type, frame_type)
|
||||
self.assertIsNone(value.fg_probability)
|
||||
|
||||
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
|
||||
for batch in dataset_map.test.get_eval_batches():
|
||||
|
Loading…
x
Reference in New Issue
Block a user