mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
provide fg_probability for blender data
Summary: The blender synthetic dataset contains object masks in the alpha channel. Provide these in the corresponding dataset. Reviewed By: shapovalov Differential Revision: D37344380 fbshipit-source-id: 3ddacad9d667c0fa0ae5a61fb1d2ffc806c9abf3
This commit is contained in:
parent
731ea53c80
commit
3e4fb0b9d9
@ -42,11 +42,13 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
|||||||
)
|
)
|
||||||
H, W, focal = hwf
|
H, W, focal = hwf
|
||||||
H, W = int(H), int(W)
|
H, W = int(H), int(W)
|
||||||
images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3]
|
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||||
|
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.images = images
|
self.images = images_masks[:, :3]
|
||||||
|
# pyre-ignore[16]
|
||||||
|
self.fg_probabilities = images_masks[:, 3:4]
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.i_split = i_split
|
self.i_split = i_split
|
||||||
|
@ -58,4 +58,6 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
self.images = images
|
self.images = images
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
|
self.fg_probabilities = None
|
||||||
|
# pyre-ignore[16]
|
||||||
self.i_split = i_split
|
self.i_split = i_split
|
||||||
|
@ -38,6 +38,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
images: List[torch.Tensor] = field()
|
images: List[torch.Tensor] = field()
|
||||||
|
fg_probabilities: Optional[List[torch.Tensor]] = field()
|
||||||
poses: List[PerspectiveCameras] = field()
|
poses: List[PerspectiveCameras] = field()
|
||||||
object_name: str = field()
|
object_name: str = field()
|
||||||
frame_types: List[str] = field()
|
frame_types: List[str] = field()
|
||||||
@ -55,6 +56,9 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
|||||||
image = self.images[index]
|
image = self.images[index]
|
||||||
pose = self.poses[index]
|
pose = self.poses[index]
|
||||||
frame_type = self.frame_types[index]
|
frame_type = self.frame_types[index]
|
||||||
|
fg_probability = (
|
||||||
|
None if self.fg_probabilities is None else self.fg_probabilities[index]
|
||||||
|
)
|
||||||
|
|
||||||
frame_data = FrameData(
|
frame_data = FrameData(
|
||||||
frame_number=index,
|
frame_number=index,
|
||||||
@ -63,6 +67,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
|||||||
camera=pose,
|
camera=pose,
|
||||||
image_size_hw=torch.tensor(image.shape[1:]),
|
image_size_hw=torch.tensor(image.shape[1:]),
|
||||||
image_rgb=image,
|
image_rgb=image,
|
||||||
|
fg_probability=fg_probability,
|
||||||
frame_type=frame_type,
|
frame_type=frame_type,
|
||||||
)
|
)
|
||||||
return frame_data
|
return frame_data
|
||||||
@ -100,7 +105,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
|
|
||||||
def _load_data(self) -> None:
|
def _load_data(self) -> None:
|
||||||
# This must be defined by each subclass,
|
# This must be defined by each subclass,
|
||||||
# and should set poses, images and i_split on self.
|
# and should set the following on self.
|
||||||
|
# - poses: a list of length-1 camera objects
|
||||||
|
# - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
|
||||||
|
# - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
|
||||||
|
# - splits: List[List[int]] of indices for train/val/test subsets.
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _get_dataset(
|
def _get_dataset(
|
||||||
@ -110,6 +119,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
split = self.i_split[split_idx]
|
split = self.i_split[split_idx]
|
||||||
frame_types = [frame_type] * len(split)
|
frame_types = [frame_type] * len(split)
|
||||||
|
fg_probabilities = (
|
||||||
|
None
|
||||||
|
# pyre-ignore[16]
|
||||||
|
if self.fg_probabilities is None
|
||||||
|
else self.fg_probabilities[split]
|
||||||
|
)
|
||||||
eval_batches = [[i] for i in range(len(split))]
|
eval_batches = [[i] for i in range(len(split))]
|
||||||
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
||||||
train_split = self.i_split[0]
|
train_split = self.i_split[0]
|
||||||
@ -130,6 +145,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
|||||||
object_name=self.object_name,
|
object_name=self.object_name,
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
images=self.images[split],
|
images=self.images[split],
|
||||||
|
fg_probabilities=fg_probabilities,
|
||||||
# pyre-ignore[16]
|
# pyre-ignore[16]
|
||||||
poses=[self.poses[i] for i in split],
|
poses=[self.poses[i] for i in split],
|
||||||
frame_types=frame_types,
|
frame_types=frame_types,
|
||||||
|
@ -41,6 +41,10 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
|||||||
# try getting a value
|
# try getting a value
|
||||||
value = dataset[0]
|
value = dataset[0]
|
||||||
self.assertEqual(value.image_rgb.shape, (3, 800, 800))
|
self.assertEqual(value.image_rgb.shape, (3, 800, 800))
|
||||||
|
self.assertEqual(value.fg_probability.shape, (1, 800, 800))
|
||||||
|
# corner of image is background
|
||||||
|
self.assertEqual(value.fg_probability[0, 0, 0], 0)
|
||||||
|
self.assertEqual(value.fg_probability.max(), 1.0)
|
||||||
self.assertIsInstance(value, FrameData)
|
self.assertIsInstance(value, FrameData)
|
||||||
|
|
||||||
def test_llff(self):
|
def test_llff(self):
|
||||||
@ -90,6 +94,7 @@ class TestDataLlff(TestCaseMixin, unittest.TestCase):
|
|||||||
for i, frame_type in enumerate(types):
|
for i, frame_type in enumerate(types):
|
||||||
value = dataset[i]
|
value = dataset[i]
|
||||||
self.assertEqual(value.frame_type, frame_type)
|
self.assertEqual(value.frame_type, frame_type)
|
||||||
|
self.assertIsNone(value.fg_probability)
|
||||||
|
|
||||||
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
|
self.assertEqual(len(dataset_map.test.get_eval_batches()), 3)
|
||||||
for batch in dataset_map.test.get_eval_batches():
|
for batch in dataset_map.test.get_eval_batches():
|
||||||
|
Loading…
x
Reference in New Issue
Block a user