From b0462598ac59ba01c9866a3268f9c0da44d62456 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Wed, 17 May 2023 10:38:34 -0700 Subject: [PATCH] Refactor: FrameDataBuilder is more extensible. Summary: This is mostly a refactoring diff to reduce friction in extending the frame data. Slight functional changes: the dataset's `__getitem__` now accepts `(seq_name, frame_number_as_singleton_tensor)` as a non-advertised feature. Without this change, the following code crashes: ``` item = dataset[0] dataset[item.sequence_name, item.frame_number] ``` Reviewed By: bottler Differential Revision: D45780175 fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29 --- .../tests/test_experiment.py | 5 ++ pytorch3d/implicitron/dataset/frame_data.py | 77 ++++++++++++------- pytorch3d/implicitron/dataset/sql_dataset.py | 3 + pytorch3d/implicitron/dataset/utils.py | 14 ++-- tests/implicitron/test_frame_data_builder.py | 44 +++++++++-- 5 files changed, 102 insertions(+), 41 deletions(-) diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index a07eb8be..590102fa 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -132,6 +132,11 @@ class TestExperiment(unittest.TestCase): # Check that the default config values, defined by Experiment and its # members, is what we expect it to be.
cfg = OmegaConf.structured(experiment.Experiment) + # the following removes the possible effect of env variables + ds_arg = cfg.data_source_ImplicitronDataSource_args + ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = "" + ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = "" + cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097 yaml = OmegaConf.to_yaml(cfg, sort_keys=False) if DEBUG: (DATA_DIR / "experiment.yaml").write_text(yaml) diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index 9fff7195..5a8100b8 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -203,7 +203,10 @@ class FrameData(Mapping[str, Any]): when no image has been loaded) """ if self.bbox_xywh is None: - raise ValueError("Attempted cropping by metadata with empty bounding box") + raise ValueError( + "Attempted cropping by metadata with empty bounding box. Consider either" + " to remove_empty_masks or turn off box_crop in the dataset config." + ) if not self._uncropped: raise ValueError( @@ -528,12 +531,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): "Make sure it is set in either FrameDataBuilder or Dataset params." ) - if self.path_manager is None: - dataset_root_exists = os.path.isdir(self.dataset_root) # pyre-ignore - else: - dataset_root_exists = self.path_manager.isdir(self.dataset_root) - - if load_any_blob and not dataset_root_exists: + if load_any_blob and not self._exists_in_dataset_root(""): raise ValueError( f"dataset_root is passed but {self.dataset_root} does not exist." 
) @@ -604,14 +602,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): frame_data.image_size_hw = image_size_hw # original image size # image size after crop/resize frame_data.effective_image_size_hw = image_size_hw + image_path = None + dataset_root = self.dataset_root + if frame_annotation.image.path is not None and dataset_root is not None: + image_path = os.path.join(dataset_root, frame_annotation.image.path) + frame_data.image_path = image_path if load_blobs and self.load_images: - ( - frame_data.image_rgb, - frame_data.image_path, - ) = self._load_images(frame_annotation, frame_data.fg_probability) + if image_path is None: + raise ValueError("Image path is required to load images.") - if load_blobs and self.load_depths and frame_annotation.depth is not None: + image_np = load_image(self._local_path(image_path)) + frame_data.image_rgb = self._postprocess_image( + image_np, frame_annotation.image.size, frame_data.fg_probability + ) + + if ( + load_blobs + and self.load_depths + and frame_annotation.depth is not None + and frame_annotation.depth.path is not None + ): ( frame_data.depth_map, frame_data.depth_path, @@ -652,25 +663,22 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): return fg_probability, full_path - def _load_images( + def _postprocess_image( self, - entry: types.FrameAnnotation, + image_np: np.ndarray, + image_size: Tuple[int, int], fg_probability: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, str]: - assert self.dataset_root is not None and entry.image is not None - path = os.path.join(self.dataset_root, entry.image.path) - image_rgb = load_image(self._local_path(path)) + ) -> torch.Tensor: + image_rgb = safe_as_tensor(image_np, torch.float) - if image_rgb.shape[-2:] != entry.image.size: - raise ValueError( - f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!" 
- ) + if image_rgb.shape[-2:] != image_size: + raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!") if self.mask_images: assert fg_probability is not None image_rgb *= fg_probability - return image_rgb, path + return image_rgb def _load_mask_depth( self, @@ -678,18 +686,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): fg_probability: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, str, torch.Tensor]: entry_depth = entry.depth - assert self.dataset_root is not None and entry_depth is not None - path = os.path.join(self.dataset_root, entry_depth.path) + dataset_root = self.dataset_root + assert dataset_root is not None + assert entry_depth is not None and entry_depth.path is not None + path = os.path.join(dataset_root, entry_depth.path) depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment) if self.mask_depths: assert fg_probability is not None depth_map *= fg_probability - if self.load_depth_masks: - assert entry_depth.mask_path is not None - # pyre-ignore - mask_path = os.path.join(self.dataset_root, entry_depth.mask_path) + mask_path = entry_depth.mask_path + if self.load_depth_masks and mask_path is not None: + mask_path = os.path.join(dataset_root, mask_path) depth_mask = load_depth_mask(self._local_path(mask_path)) else: depth_mask = torch.ones_like(depth_map) @@ -745,6 +754,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): return path return self.path_manager.get_local_path(path) + def _exists_in_dataset_root(self, relpath) -> bool: + if not self.dataset_root: + return False + + full_path = os.path.join(self.dataset_root, relpath) + if self.path_manager is None: + return os.path.exists(full_path) + else: + return self.path_manager.exists(full_path) + @registry.register class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]): diff --git a/pytorch3d/implicitron/dataset/sql_dataset.py b/pytorch3d/implicitron/dataset/sql_dataset.py 
index 1eb26964..4c9d3bb5 100644 --- a/pytorch3d/implicitron/dataset/sql_dataset.py +++ b/pytorch3d/implicitron/dataset/sql_dataset.py @@ -210,6 +210,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore seq, frame = self._index.index[frame_idx] else: seq, frame, *rest = frame_idx + if isinstance(frame, torch.LongTensor): + frame = frame.item() + if (seq, frame) not in self._index.index: raise IndexError( f"Sequence-frame index {frame_idx} not found; was it filtered out?" diff --git a/pytorch3d/implicitron/dataset/utils.py b/pytorch3d/implicitron/dataset/utils.py index 0982fbc0..2c31a174 100644 --- a/pytorch3d/implicitron/dataset/utils.py +++ b/pytorch3d/implicitron/dataset/utils.py @@ -225,19 +225,23 @@ def resize_image( return imre_, minscale, mask +def transpose_normalize_image(image: np.ndarray) -> np.ndarray: + im = np.atleast_3d(image).transpose((2, 0, 1)) + return im.astype(np.float32) / 255.0 + + def load_image(path: str) -> np.ndarray: with Image.open(path) as pil_im: im = np.array(pil_im.convert("RGB")) - im = im.transpose((2, 0, 1)) - im = im.astype(np.float32) / 255.0 - return im + + return transpose_normalize_image(im) def load_mask(path: str) -> np.ndarray: with Image.open(path) as pil_im: mask = np.array(pil_im) - mask = mask.astype(np.float32) / 255.0 - return mask[None] # fake feature channel + + return transpose_normalize_image(mask) def load_depth(path: str, scale_adjustment: float) -> np.ndarray: diff --git a/tests/implicitron/test_frame_data_builder.py b/tests/implicitron/test_frame_data_builder.py index e66d67df..28b5e046 100644 --- a/tests/implicitron/test_frame_data_builder.py +++ b/tests/implicitron/test_frame_data_builder.py @@ -25,6 +25,7 @@ from pytorch3d.implicitron.dataset.utils import ( load_image, load_mask, safe_as_tensor, + transpose_normalize_image, ) from pytorch3d.implicitron.tools.config import get_default_args from pytorch3d.renderer.cameras import PerspectiveCameras @@ -123,14 +124,15 @@ class 
TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): # assert bboxes shape self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4])) - ( - self.frame_data.image_rgb, - self.frame_data.image_path, - ) = self.frame_data_builder._load_images( - self.frame_annotation, self.frame_data.fg_probability + image_path = os.path.join( + self.frame_data_builder.dataset_root, self.frame_annotation.image.path ) - self.assertEqual(type(self.frame_data.image_rgb), np.ndarray) - self.assertIsNotNone(self.frame_data.image_path) + image_np = load_image(self.frame_data_builder._local_path(image_path)) + self.assertIsInstance(image_np, np.ndarray) + self.frame_data.image_rgb = self.frame_data_builder._postprocess_image( + image_np, self.frame_annotation.image.size, self.frame_data.fg_probability + ) + self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor) ( self.frame_data.depth_map, @@ -184,6 +186,34 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ) self.assertEqual(type(self.frame_data.camera), PerspectiveCameras) + def test_transpose_normalize_image(self): + def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray: + im = image * 255.0 + return im.transpose((1, 2, 0)).astype(np.uint8) + + # Test 2D input + input_image = np.array( + [[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8 + ) + expected_input = inverse_transpose_normalize_image( + transpose_normalize_image(input_image) + ) + self.assertClose(input_image[..., None], expected_input) + + # Test 3D input + input_image = np.array( + [ + [[10, 20, 30], [40, 50, 60], [70, 80, 90]], + [[100, 110, 120], [130, 140, 150], [160, 170, 180]], + [[190, 200, 210], [220, 230, 240], [250, 255, 255]], + ], + dtype=np.uint8, + ) + expected_input = inverse_transpose_normalize_image( + transpose_normalize_image(input_image) + ) + self.assertClose(input_image, expected_input) + def test_load_image(self): path = os.path.join(self.dataset_root, self.frame_annotation.image.path) local_path = 
self.path_manager.get_local_path(path)