mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Refactor: FrameDataBuilder is more extensible.
Summary:
This is mostly a refactoring diff to reduce friction in extending the frame data.

Slight functional changes: dataset `__getitem__` now accepts `(seq_name, frame_number_as_singleton_tensor)` as a non-advertised feature. Otherwise, this code crashes:

```
item = dataset[0]
dataset[item.sequence_name, item.frame_number]
```

Reviewed By: bottler

Differential Revision: D45780175

fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
This commit is contained in:
		
							parent
							
								
									d08fe6d45a
								
							
						
					
					
						commit
						b0462598ac
					
				@ -132,6 +132,11 @@ class TestExperiment(unittest.TestCase):
 | 
			
		||||
        # Check that the default config values, defined by Experiment and its
 | 
			
		||||
        # members, is what we expect it to be.
 | 
			
		||||
        cfg = OmegaConf.structured(experiment.Experiment)
 | 
			
		||||
        # the following removes the possible effect of env variables
 | 
			
		||||
        ds_arg = cfg.data_source_ImplicitronDataSource_args
 | 
			
		||||
        ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
 | 
			
		||||
        ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
 | 
			
		||||
        cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097
 | 
			
		||||
        yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
 | 
			
		||||
        if DEBUG:
 | 
			
		||||
            (DATA_DIR / "experiment.yaml").write_text(yaml)
 | 
			
		||||
 | 
			
		||||
@ -203,7 +203,10 @@ class FrameData(Mapping[str, Any]):
 | 
			
		||||
                when no image has been loaded)
 | 
			
		||||
        """
 | 
			
		||||
        if self.bbox_xywh is None:
 | 
			
		||||
            raise ValueError("Attempted cropping by metadata with empty bounding box")
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                "Attempted cropping by metadata with empty bounding box. Consider either"
 | 
			
		||||
                " to remove_empty_masks or turn off box_crop in the dataset config."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if not self._uncropped:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
@ -528,12 +531,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
                "Make sure it is set in either FrameDataBuilder or Dataset params."
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        if self.path_manager is None:
 | 
			
		||||
            dataset_root_exists = os.path.isdir(self.dataset_root)  # pyre-ignore
 | 
			
		||||
        else:
 | 
			
		||||
            dataset_root_exists = self.path_manager.isdir(self.dataset_root)
 | 
			
		||||
 | 
			
		||||
        if load_any_blob and not dataset_root_exists:
 | 
			
		||||
        if load_any_blob and not self._exists_in_dataset_root(""):
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"dataset_root is passed but {self.dataset_root} does not exist."
 | 
			
		||||
            )
 | 
			
		||||
@ -604,14 +602,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
            frame_data.image_size_hw = image_size_hw  # original image size
 | 
			
		||||
            # image size after crop/resize
 | 
			
		||||
            frame_data.effective_image_size_hw = image_size_hw
 | 
			
		||||
            image_path = None
 | 
			
		||||
            dataset_root = self.dataset_root
 | 
			
		||||
            if frame_annotation.image.path is not None and dataset_root is not None:
 | 
			
		||||
                image_path = os.path.join(dataset_root, frame_annotation.image.path)
 | 
			
		||||
                frame_data.image_path = image_path
 | 
			
		||||
 | 
			
		||||
            if load_blobs and self.load_images:
 | 
			
		||||
                (
 | 
			
		||||
                    frame_data.image_rgb,
 | 
			
		||||
                    frame_data.image_path,
 | 
			
		||||
                ) = self._load_images(frame_annotation, frame_data.fg_probability)
 | 
			
		||||
                if image_path is None:
 | 
			
		||||
                    raise ValueError("Image path is required to load images.")
 | 
			
		||||
 | 
			
		||||
        if load_blobs and self.load_depths and frame_annotation.depth is not None:
 | 
			
		||||
                image_np = load_image(self._local_path(image_path))
 | 
			
		||||
                frame_data.image_rgb = self._postprocess_image(
 | 
			
		||||
                    image_np, frame_annotation.image.size, frame_data.fg_probability
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        if (
 | 
			
		||||
            load_blobs
 | 
			
		||||
            and self.load_depths
 | 
			
		||||
            and frame_annotation.depth is not None
 | 
			
		||||
            and frame_annotation.depth.path is not None
 | 
			
		||||
        ):
 | 
			
		||||
            (
 | 
			
		||||
                frame_data.depth_map,
 | 
			
		||||
                frame_data.depth_path,
 | 
			
		||||
@ -652,25 +663,22 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
 | 
			
		||||
        return fg_probability, full_path
 | 
			
		||||
 | 
			
		||||
    def _load_images(
 | 
			
		||||
    def _postprocess_image(
 | 
			
		||||
        self,
 | 
			
		||||
        entry: types.FrameAnnotation,
 | 
			
		||||
        image_np: np.ndarray,
 | 
			
		||||
        image_size: Tuple[int, int],
 | 
			
		||||
        fg_probability: Optional[torch.Tensor],
 | 
			
		||||
    ) -> Tuple[torch.Tensor, str]:
 | 
			
		||||
        assert self.dataset_root is not None and entry.image is not None
 | 
			
		||||
        path = os.path.join(self.dataset_root, entry.image.path)
 | 
			
		||||
        image_rgb = load_image(self._local_path(path))
 | 
			
		||||
    ) -> torch.Tensor:
 | 
			
		||||
        image_rgb = safe_as_tensor(image_np, torch.float)
 | 
			
		||||
 | 
			
		||||
        if image_rgb.shape[-2:] != entry.image.size:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
 | 
			
		||||
            )
 | 
			
		||||
        if image_rgb.shape[-2:] != image_size:
 | 
			
		||||
            raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")
 | 
			
		||||
 | 
			
		||||
        if self.mask_images:
 | 
			
		||||
            assert fg_probability is not None
 | 
			
		||||
            image_rgb *= fg_probability
 | 
			
		||||
 | 
			
		||||
        return image_rgb, path
 | 
			
		||||
        return image_rgb
 | 
			
		||||
 | 
			
		||||
    def _load_mask_depth(
 | 
			
		||||
        self,
 | 
			
		||||
@ -678,18 +686,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
        fg_probability: Optional[torch.Tensor],
 | 
			
		||||
    ) -> Tuple[torch.Tensor, str, torch.Tensor]:
 | 
			
		||||
        entry_depth = entry.depth
 | 
			
		||||
        assert self.dataset_root is not None and entry_depth is not None
 | 
			
		||||
        path = os.path.join(self.dataset_root, entry_depth.path)
 | 
			
		||||
        dataset_root = self.dataset_root
 | 
			
		||||
        assert dataset_root is not None
 | 
			
		||||
        assert entry_depth is not None and entry_depth.path is not None
 | 
			
		||||
        path = os.path.join(dataset_root, entry_depth.path)
 | 
			
		||||
        depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
 | 
			
		||||
 | 
			
		||||
        if self.mask_depths:
 | 
			
		||||
            assert fg_probability is not None
 | 
			
		||||
            depth_map *= fg_probability
 | 
			
		||||
 | 
			
		||||
        if self.load_depth_masks:
 | 
			
		||||
            assert entry_depth.mask_path is not None
 | 
			
		||||
            # pyre-ignore
 | 
			
		||||
            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
 | 
			
		||||
        mask_path = entry_depth.mask_path
 | 
			
		||||
        if self.load_depth_masks and mask_path is not None:
 | 
			
		||||
            mask_path = os.path.join(dataset_root, mask_path)
 | 
			
		||||
            depth_mask = load_depth_mask(self._local_path(mask_path))
 | 
			
		||||
        else:
 | 
			
		||||
            depth_mask = torch.ones_like(depth_map)
 | 
			
		||||
@ -745,6 +754,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 | 
			
		||||
            return path
 | 
			
		||||
        return self.path_manager.get_local_path(path)
 | 
			
		||||
 | 
			
		||||
    def _exists_in_dataset_root(self, relpath) -> bool:
 | 
			
		||||
        if not self.dataset_root:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        full_path = os.path.join(self.dataset_root, relpath)
 | 
			
		||||
        if self.path_manager is None:
 | 
			
		||||
            return os.path.exists(full_path)
 | 
			
		||||
        else:
 | 
			
		||||
            return self.path_manager.exists(full_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@registry.register
 | 
			
		||||
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
 | 
			
		||||
 | 
			
		||||
@ -210,6 +210,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
 | 
			
		||||
            seq, frame = self._index.index[frame_idx]
 | 
			
		||||
        else:
 | 
			
		||||
            seq, frame, *rest = frame_idx
 | 
			
		||||
            if isinstance(frame, torch.LongTensor):
 | 
			
		||||
                frame = frame.item()
 | 
			
		||||
 | 
			
		||||
            if (seq, frame) not in self._index.index:
 | 
			
		||||
                raise IndexError(
 | 
			
		||||
                    f"Sequence-frame index {frame_idx} not found; was it filtered out?"
 | 
			
		||||
 | 
			
		||||
@ -225,19 +225,23 @@ def resize_image(
 | 
			
		||||
    return imre_, minscale, mask
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
 | 
			
		||||
    im = np.atleast_3d(image).transpose((2, 0, 1))
 | 
			
		||||
    return im.astype(np.float32) / 255.0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_image(path: str) -> np.ndarray:
 | 
			
		||||
    with Image.open(path) as pil_im:
 | 
			
		||||
        im = np.array(pil_im.convert("RGB"))
 | 
			
		||||
    im = im.transpose((2, 0, 1))
 | 
			
		||||
    im = im.astype(np.float32) / 255.0
 | 
			
		||||
    return im
 | 
			
		||||
 | 
			
		||||
    return transpose_normalize_image(im)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_mask(path: str) -> np.ndarray:
 | 
			
		||||
    with Image.open(path) as pil_im:
 | 
			
		||||
        mask = np.array(pil_im)
 | 
			
		||||
    mask = mask.astype(np.float32) / 255.0
 | 
			
		||||
    return mask[None]  # fake feature channel
 | 
			
		||||
 | 
			
		||||
    return transpose_normalize_image(mask)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
 | 
			
		||||
 | 
			
		||||
@ -25,6 +25,7 @@ from pytorch3d.implicitron.dataset.utils import (
 | 
			
		||||
    load_image,
 | 
			
		||||
    load_mask,
 | 
			
		||||
    safe_as_tensor,
 | 
			
		||||
    transpose_normalize_image,
 | 
			
		||||
)
 | 
			
		||||
from pytorch3d.implicitron.tools.config import get_default_args
 | 
			
		||||
from pytorch3d.renderer.cameras import PerspectiveCameras
 | 
			
		||||
@ -123,14 +124,15 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        # assert bboxes shape
 | 
			
		||||
        self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
 | 
			
		||||
 | 
			
		||||
        (
 | 
			
		||||
            self.frame_data.image_rgb,
 | 
			
		||||
            self.frame_data.image_path,
 | 
			
		||||
        ) = self.frame_data_builder._load_images(
 | 
			
		||||
            self.frame_annotation, self.frame_data.fg_probability
 | 
			
		||||
        image_path = os.path.join(
 | 
			
		||||
            self.frame_data_builder.dataset_root, self.frame_annotation.image.path
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
 | 
			
		||||
        self.assertIsNotNone(self.frame_data.image_path)
 | 
			
		||||
        image_np = load_image(self.frame_data_builder._local_path(image_path))
 | 
			
		||||
        self.assertIsInstance(image_np, np.ndarray)
 | 
			
		||||
        self.frame_data.image_rgb = self.frame_data_builder._postprocess_image(
 | 
			
		||||
            image_np, self.frame_annotation.image.size, self.frame_data.fg_probability
 | 
			
		||||
        )
 | 
			
		||||
        self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)
 | 
			
		||||
 | 
			
		||||
        (
 | 
			
		||||
            self.frame_data.depth_map,
 | 
			
		||||
@ -184,6 +186,34 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
 | 
			
		||||
        )
 | 
			
		||||
        self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
 | 
			
		||||
 | 
			
		||||
    def test_transpose_normalize_image(self):
 | 
			
		||||
        def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray:
 | 
			
		||||
            im = image * 255.0
 | 
			
		||||
            return im.transpose((1, 2, 0)).astype(np.uint8)
 | 
			
		||||
 | 
			
		||||
        # Test 2D input
 | 
			
		||||
        input_image = np.array(
 | 
			
		||||
            [[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8
 | 
			
		||||
        )
 | 
			
		||||
        expected_input = inverse_transpose_normalize_image(
 | 
			
		||||
            transpose_normalize_image(input_image)
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(input_image[..., None], expected_input)
 | 
			
		||||
 | 
			
		||||
        # Test 3D input
 | 
			
		||||
        input_image = np.array(
 | 
			
		||||
            [
 | 
			
		||||
                [[10, 20, 30], [40, 50, 60], [70, 80, 90]],
 | 
			
		||||
                [[100, 110, 120], [130, 140, 150], [160, 170, 180]],
 | 
			
		||||
                [[190, 200, 210], [220, 230, 240], [250, 255, 255]],
 | 
			
		||||
            ],
 | 
			
		||||
            dtype=np.uint8,
 | 
			
		||||
        )
 | 
			
		||||
        expected_input = inverse_transpose_normalize_image(
 | 
			
		||||
            transpose_normalize_image(input_image)
 | 
			
		||||
        )
 | 
			
		||||
        self.assertClose(input_image, expected_input)
 | 
			
		||||
 | 
			
		||||
    def test_load_image(self):
 | 
			
		||||
        path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
 | 
			
		||||
        local_path = self.path_manager.get_local_path(path)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user