In FrameDataBuilder, set all path even if we don’t load blobs

Summary:
This is a somewhat not BC change: some None paths will be replaced by metadata paths, even when they were not used for data loading.

Moreover, removing the legacy fix to the paths in the old CO3D release.

Reviewed By: bottler

Differential Revision: D69048238

fbshipit-source-id: 2a8b26d7b9f5e2adf39c65888b5863a5a9de1996
This commit is contained in:
Roman Shapovalov 2025-02-06 09:41:44 -08:00 committed by Facebook GitHub Bot
parent 43cd681d4f
commit 215590b497
2 changed files with 43 additions and 58 deletions

View File

@ -591,14 +591,36 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
),
)
dataset_root = self.dataset_root
mask_annotation = frame_annotation.mask
depth_annotation = frame_annotation.depth
image_path: str | None = None
mask_path: str | None = None
depth_path: str | None = None
pcl_path: str | None = None
if dataset_root is not None: # set all paths even if we wont load blobs
if frame_annotation.image.path is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path
if mask_annotation is not None and mask_annotation.path:
mask_path = os.path.join(dataset_root, mask_annotation.path)
frame_data.mask_path = mask_path
if depth_annotation is not None and depth_annotation.path is not None:
depth_path = os.path.join(dataset_root, depth_annotation.path)
frame_data.depth_path = depth_path
if point_cloud is not None:
pcl_path = os.path.join(dataset_root, point_cloud.path)
frame_data.sequence_point_cloud_path = pcl_path
fg_mask_np: np.ndarray | None = None
bbox_xywh: tuple[float, float, float, float] | None = None
mask_annotation = frame_annotation.mask
if mask_annotation is not None:
if load_blobs and self.load_masks:
fg_mask_np, mask_path = self._load_fg_probability(frame_annotation)
frame_data.mask_path = mask_path
if load_blobs and self.load_masks and mask_path:
fg_mask_np = self._load_fg_probability(frame_annotation, mask_path)
frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
bbox_xywh = mask_annotation.bounding_box_xywh
@ -608,11 +630,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
frame_data.image_size_hw = image_size_hw # original image size
# image size after crop/resize
frame_data.effective_image_size_hw = image_size_hw
image_path = None
dataset_root = self.dataset_root
if frame_annotation.image.path is not None and dataset_root is not None:
image_path = os.path.join(dataset_root, frame_annotation.image.path)
frame_data.image_path = image_path
if load_blobs and self.load_images:
if image_path is None:
@ -639,25 +656,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
bbox_xywh = get_bbox_from_mask(fg_mask_np, self.box_crop_mask_thr)
frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.float)
depth_annotation = frame_annotation.depth
if (
load_blobs
and self.load_depths
and depth_annotation is not None
and depth_annotation.path is not None
):
(
frame_data.depth_map,
frame_data.depth_path,
frame_data.depth_mask,
) = self._load_mask_depth(frame_annotation, fg_mask_np)
if load_blobs and self.load_depths and depth_path is not None:
frame_data.depth_map, frame_data.depth_mask = self._load_mask_depth(
frame_annotation, depth_path, fg_mask_np
)
if load_blobs and self.load_point_clouds and point_cloud is not None:
pcl_path = self._fix_point_cloud_path(point_cloud.path)
assert pcl_path is not None
frame_data.sequence_point_cloud = load_pointcloud(
self._local_path(pcl_path), max_points=self.max_points
)
frame_data.sequence_point_cloud_path = pcl_path
if frame_annotation.viewpoint is not None:
frame_data.camera = self._get_pytorch3d_camera(frame_annotation)
@ -673,17 +681,14 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
return frame_data
def _load_fg_probability(self, entry: FrameAnnotationT) -> Tuple[np.ndarray, str]:
mask_annotation = entry.mask
assert self.dataset_root is not None and mask_annotation is not None
full_path = os.path.join(self.dataset_root, mask_annotation.path)
fg_probability = load_mask(self._local_path(full_path))
def _load_fg_probability(self, entry: FrameAnnotationT, path: str) -> np.ndarray:
fg_probability = load_mask(self._local_path(path))
if fg_probability.shape[-2:] != entry.image.size:
raise ValueError(
f"bad mask size: {fg_probability.shape[-2:]} vs {entry.image.size}!"
)
return fg_probability, full_path
return fg_probability
def _postprocess_image(
self,
@ -705,13 +710,13 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
def _load_mask_depth(
self,
entry: FrameAnnotationT,
path: str,
fg_mask: Optional[np.ndarray],
) -> Tuple[torch.Tensor, str, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
entry_depth = entry.depth
dataset_root = self.dataset_root
assert dataset_root is not None
assert entry_depth is not None and entry_depth.path is not None
path = os.path.join(dataset_root, entry_depth.path)
assert entry_depth is not None
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
if self.mask_depths:
@ -725,7 +730,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
else:
depth_mask = (depth_map > 0.0).astype(np.float32)
return torch.tensor(depth_map), path, torch.tensor(depth_mask)
return torch.tensor(depth_map), torch.tensor(depth_mask)
def _get_pytorch3d_camera(
self,
@ -758,19 +763,6 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
T=torch.tensor(entry_viewpoint.T, dtype=torch.float)[None],
)
def _fix_point_cloud_path(self, path: str) -> str:
"""
Fix up a point cloud path from the dataset.
Some files in Co3Dv2 have an accidental absolute path stored.
"""
unwanted_prefix = (
"/large_experiments/p3/replay/datasets/co3d/co3d45k_220512/export_v23/"
)
if path.startswith(unwanted_prefix):
path = path[len(unwanted_prefix) :]
assert self.dataset_root is not None
return os.path.join(self.dataset_root, path)
def _local_path(self, path: str) -> str:
if self.path_manager is None:
return path

View File

@ -96,21 +96,15 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
# test that FrameDataBuilder works with get_default_args
get_default_args(FrameDataBuilder)
def test_fix_point_cloud_path(self):
"""Some files in Co3Dv2 have an accidental absolute path stored."""
original_path = "some_file_path"
modified_path = self.frame_data_builder._fix_point_cloud_path(original_path)
self.assertIn(original_path, modified_path)
self.assertIn(self.frame_data_builder.dataset_root, modified_path)
def test_load_and_adjust_frame_data(self):
self.frame_data.image_size_hw = safe_as_tensor(
self.frame_annotation.image.size, torch.long
)
self.frame_data.effective_image_size_hw = self.frame_data.image_size_hw
fg_mask_np, mask_path = self.frame_data_builder._load_fg_probability(
self.frame_annotation
mask_path = os.path.join(self.dataset_root, self.frame_annotation.mask.path)
fg_mask_np = self.frame_data_builder._load_fg_probability(
self.frame_annotation, mask_path
)
self.frame_data.mask_path = mask_path
self.frame_data.fg_probability = safe_as_tensor(fg_mask_np, torch.float)
@ -118,7 +112,6 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
bbox_xywh = get_bbox_from_mask(fg_mask_np, mask_thr)
self.frame_data.bbox_xywh = safe_as_tensor(bbox_xywh, torch.long)
self.assertIsNotNone(self.frame_data.mask_path)
self.assertTrue(torch.is_tensor(self.frame_data.fg_probability))
self.assertTrue(torch.is_tensor(self.frame_data.bbox_xywh))
# assert bboxes shape
@ -134,16 +127,16 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
)
self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)
depth_path = os.path.join(self.dataset_root, self.frame_annotation.depth.path)
(
self.frame_data.depth_map,
depth_path,
self.frame_data.depth_mask,
) = self.frame_data_builder._load_mask_depth(
self.frame_annotation,
depth_path,
self.frame_data.fg_probability,
)
self.assertTrue(torch.is_tensor(self.frame_data.depth_map))
self.assertIsNotNone(depth_path)
self.assertTrue(torch.is_tensor(self.frame_data.depth_mask))
new_size = (self.image_height, self.image_width)