diff --git a/pytorch3d/implicitron/dataset/frame_data.py b/pytorch3d/implicitron/dataset/frame_data.py index ed88c0f8..5f547905 100644 --- a/pytorch3d/implicitron/dataset/frame_data.py +++ b/pytorch3d/implicitron/dataset/frame_data.py @@ -591,14 +591,36 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): ), ) + dataset_root = self.dataset_root + mask_annotation = frame_annotation.mask + depth_annotation = frame_annotation.depth + image_path: str | None = None + mask_path: str | None = None + depth_path: str | None = None + pcl_path: str | None = None + if dataset_root is not None: # set all paths even if we won’t load blobs + if frame_annotation.image.path is not None: + image_path = os.path.join(dataset_root, frame_annotation.image.path) + frame_data.image_path = image_path + + if mask_annotation is not None and mask_annotation.path: + mask_path = os.path.join(dataset_root, mask_annotation.path) + frame_data.mask_path = mask_path + + if depth_annotation is not None and depth_annotation.path is not None: + depth_path = os.path.join(dataset_root, depth_annotation.path) + frame_data.depth_path = depth_path + + if point_cloud is not None: + pcl_path = os.path.join(dataset_root, point_cloud.path) + frame_data.sequence_point_cloud_path = pcl_path + fg_mask_np: np.ndarray | None = None bbox_xywh: tuple[float, float, float, float] | None = None - mask_annotation = frame_annotation.mask if mask_annotation is not None: - if load_blobs and self.load_masks: - fg_mask_np, mask_path = self._load_fg_probability(frame_annotation) - frame_data.mask_path = mask_path + if load_blobs and self.load_masks and mask_path: + fg_mask_np = self._load_fg_probability(frame_annotation, mask_path) frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float) bbox_xywh = mask_annotation.bounding_box_xywh @@ -608,11 +630,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): frame_data.image_size_hw = image_size_hw # original image size # image size after crop/resize frame_data.effective_image_size_hw = image_size_hw - image_path = None - dataset_root = self.dataset_root - if frame_annotation.image.path is not None and dataset_root is not None: - image_path = os.path.join(dataset_root, frame_annotation.image.path) - frame_data.image_path = image_path if load_blobs and self.load_images: if image_path is None: @@ -639,25 +656,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr) frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float) - depth_annotation = frame_annotation.depth - if ( - load_blobs - and self.load_depths - and depth_annotation is not None - and depth_annotation.path is not None - ): - ( - frame_data.depth_map, - frame_data.depth_path, - frame_data.depth_mask, - ) = self._load_mask_depth(frame_annotation, fg_mask_np) + if load_blobs and self.load_depths and depth_path is not None: + frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth( + frame_annotation, depth_path, fg_mask_np + ) if load_blobs and self.load_point_clouds and point_cloud is not None: - pcl_path = self._fix_point_cloud_path(point_cloud.path) + assert pcl_path is not None frame_data.sequence_point_cloud = load_pointcloud( self._local_path(pcl_path), max_points=self.max_points ) - frame_data.sequence_point_cloud_path = pcl_path if frame_annotation.viewpoint is not None: frame_data.camera = self._get_pytorch3d_camera(frame_annotation) @@ -673,17 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): return frame_data - def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]: - mask_annotation = entry.mask - assert self.dataset_root is not None and mask_annotation is not None - full_path = os.path.join(self.dataset_root, mask_annotation.path) - fg_probability = load_mask(self._local_path(full_path)) + def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray: + fg_probability = load_mask(self._local_path(path)) if fg_probability.shape[-2:] != entry.image.size: raise ValueError( f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!" ) - return fg_probability, full_path + return fg_probability def _postprocess_image( self, @@ -705,13 +710,13 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): def _load_mask_depth( self, entry: FrameAnnotationT, + path: str, fg_mask: Optional[np.ndarray], - ) -> Tuple[torch.Tensor, str, torch.Tensor]: + ) -> tuple[torch.Tensor, torch.Tensor]: entry_depth = entry.depth dataset_root = self.dataset_root assert dataset_root is not None - assert entry_depth is not None and entry_depth.path is not None - path = os.path.join(dataset_root, entry_depth.path) + assert entry_depth is not None depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment) if self.mask_depths: @@ -725,7 +730,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): else: depth_mask = (depth_map > 0.0).astype(np.float32) - return torch.tensor(depth_map), path, torch.tensor(depth_mask) + return torch.tensor(depth_map), torch.tensor(depth_mask) def _get_pytorch3d_camera( self, @@ -758,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC): T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None], ) - def _fix_point_cloud_path(self, path: str) -> str: - """ - Fix up a point cloud path from the dataset. - Some files in Co3Dv2 have an accidental absolute path stored. - """ - unwanted_prefix = ( - "/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/" - ) - if path.startswith(unwanted_prefix): - path = path[len(unwanted_prefix) :] - assert self.dataset_root is not None - return os.path.join(self.dataset_root, path) - def _local_path(self, path: str) -> str: if self.path_manager is None: return path diff --git a/tests/implicitron/test_frame_data_builder.py b/tests/implicitron/test_frame_data_builder.py index 28810455..611b0b94 100644 --- a/tests/implicitron/test_frame_data_builder.py +++ b/tests/implicitron/test_frame_data_builder.py @@ -96,21 +96,15 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): # test that FrameDataBuilder works with get_default_args get_default_args(FrameDataBuilder) - def test_fix_point_cloud_path(self): - """Some files in Co3Dv2 have an accidental absolute path stored.""" - original_path = "some_file_path" - modified_path = self.frame_data_builder._fix_point_cloud_path(original_path) - self.assertIn(original_path, modified_path) - self.assertIn(self.frame_data_builder.dataset_root, modified_path) - def test_load_and_adjust_frame_data(self): self.frame_data.image_size_hw = safe_as_tensor( self.frame_annotation.image.size, torch.long ) self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw - fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability( - self.frame_annotation + mask_path = os.path.join(self.dataset_root, self.frame_annotation.mask.path) + fg_mask_np = self.frame_data_builder._load_fg_probability( + self.frame_annotation, mask_path ) self.frame_data.mask_path = mask_path self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float) @@ -118,7 +112,6 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr) self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long) - self.assertIsNotNone(self.frame_data.mask_path) self.assertTrue(torch.is_tensor(self.frame_data.fg_probability)) self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh)) # assert bboxes shape @@ -134,16 +127,16 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase): ) self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor) + depth_path = os.path.join(self.dataset_root, self.frame_annotation.depth.path) ( self.frame_data.depth_map, - depth_path, self.frame_data.depth_mask, ) = self.frame_data_builder._load_mask_depth( self.frame_annotation, + depth_path, self.frame_data.fg_probability, ) self.assertTrue(torch.is_tensor(self.frame_data.depth_map)) - self.assertIsNotNone(depth_path) self.assertTrue(torch.is_tensor(self.frame_data.depth_mask)) new_size = (self.image_height, self.image_width)