mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-20 06:10:34 +08:00
provide fg_probability for blender data
Summary: The blender synthetic dataset contains object masks in the alpha channel. Provide these in the corresponding dataset. Reviewed By: shapovalov Differential Revision: D37344380 fbshipit-source-id: 3ddacad9d667c0fa0ae5a61fb1d2ffc806c9abf3
This commit is contained in:
committed by
Facebook GitHub Bot
parent
731ea53c80
commit
3e4fb0b9d9
@@ -42,11 +42,13 @@ class BlenderDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
)
|
||||
H, W, focal = hwf
|
||||
H, W = int(H), int(W)
|
||||
images = torch.from_numpy(images).permute(0, 3, 1, 2)[:, :3]
|
||||
images_masks = torch.from_numpy(images).permute(0, 3, 1, 2)
|
||||
|
||||
# pyre-ignore[16]
|
||||
self.poses = _interpret_blender_cameras(poses, H, W, focal)
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
self.images = images_masks[:, :3]
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = images_masks[:, 3:4]
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
|
||||
@@ -58,4 +58,6 @@ class LlffDatasetMapProvider(SingleSceneDatasetMapProviderBase):
|
||||
# pyre-ignore[16]
|
||||
self.images = images
|
||||
# pyre-ignore[16]
|
||||
self.fg_probabilities = None
|
||||
# pyre-ignore[16]
|
||||
self.i_split = i_split
|
||||
|
||||
@@ -38,6 +38,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
"""
|
||||
|
||||
images: List[torch.Tensor] = field()
|
||||
fg_probabilities: Optional[List[torch.Tensor]] = field()
|
||||
poses: List[PerspectiveCameras] = field()
|
||||
object_name: str = field()
|
||||
frame_types: List[str] = field()
|
||||
@@ -55,6 +56,9 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
image = self.images[index]
|
||||
pose = self.poses[index]
|
||||
frame_type = self.frame_types[index]
|
||||
fg_probability = (
|
||||
None if self.fg_probabilities is None else self.fg_probabilities[index]
|
||||
)
|
||||
|
||||
frame_data = FrameData(
|
||||
frame_number=index,
|
||||
@@ -63,6 +67,7 @@ class SingleSceneDataset(DatasetBase, Configurable):
|
||||
camera=pose,
|
||||
image_size_hw=torch.tensor(image.shape[1:]),
|
||||
image_rgb=image,
|
||||
fg_probability=fg_probability,
|
||||
frame_type=frame_type,
|
||||
)
|
||||
return frame_data
|
||||
@@ -100,7 +105,11 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
|
||||
def _load_data(self) -> None:
|
||||
# This must be defined by each subclass,
|
||||
# and should set poses, images and i_split on self.
|
||||
# and should set the following on self.
|
||||
# - poses: a list of length-1 camera objects
|
||||
# - images: [N, 3, H, W] tensor of rgb images - floats in [0,1]
|
||||
# - fg_probabilities: None or [N, 1, H, W] of floats in [0,1]
|
||||
# - i_split: List[List[int]] of indices for train/val/test subsets.
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_dataset(
|
||||
@@ -110,6 +119,12 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
# pyre-ignore[16]
|
||||
split = self.i_split[split_idx]
|
||||
frame_types = [frame_type] * len(split)
|
||||
fg_probabilities = (
|
||||
None
|
||||
# pyre-ignore[16]
|
||||
if self.fg_probabilities is None
|
||||
else self.fg_probabilities[split]
|
||||
)
|
||||
eval_batches = [[i] for i in range(len(split))]
|
||||
if split_idx != 0 and self.n_known_frames_for_test is not None:
|
||||
train_split = self.i_split[0]
|
||||
@@ -130,6 +145,7 @@ class SingleSceneDatasetMapProviderBase(DatasetMapProviderBase):
|
||||
object_name=self.object_name,
|
||||
# pyre-ignore[16]
|
||||
images=self.images[split],
|
||||
fg_probabilities=fg_probabilities,
|
||||
# pyre-ignore[16]
|
||||
poses=[self.poses[i] for i in split],
|
||||
frame_types=frame_types,
|
||||
|
||||
Reference in New Issue
Block a user