mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2026-01-17 11:50:35 +08:00
Refactor: FrameDataBuilder is more extensible.
Summary: This is mostly a refactoring diff to reduce friction in extending the frame data. Slight functional change: the dataset's `__getitem__` now accepts `(seq_name, frame_number_as_singleton_tensor)` as a non-advertised feature. Without this change, the following code crashes:

```
item = dataset[0]
dataset[item.sequence_name, item.frame_number]
```

Reviewed By: bottler

Differential Revision: D45780175

fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
Committed by Facebook GitHub Bot
Parent: d08fe6d45a
Commit: b0462598ac
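A minimal sketch of the indexing round-trip the summary describes, assuming `dataset` is an already-constructed `SqlIndexDataset`:

```
item = dataset[0]
# item.frame_number is stored as a singleton LongTensor; __getitem__ now
# converts it with .item() before the (sequence, frame) lookup (see the
# SqlIndexDataset hunk below), so this round-trip no longer crashes.
same_item = dataset[item.sequence_name, item.frame_number]
```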
@@ -203,7 +203,10 @@ class FrameData(Mapping[str, Any]):
             when no image has been loaded)
         """
         if self.bbox_xywh is None:
-            raise ValueError("Attempted cropping by metadata with empty bounding box")
+            raise ValueError(
+                "Attempted cropping by metadata with empty bounding box. Consider either"
+                " to remove_empty_masks or turn off box_crop in the dataset config."
+            )
 
         if not self._uncropped:
             raise ValueError(
@@ -528,12 +531,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
                 "Make sure it is set in either FrameDataBuilder or Dataset params."
             )
 
-        if self.path_manager is None:
-            dataset_root_exists = os.path.isdir(self.dataset_root)  # pyre-ignore
-        else:
-            dataset_root_exists = self.path_manager.isdir(self.dataset_root)
-
-        if load_any_blob and not dataset_root_exists:
+        if load_any_blob and not self._exists_in_dataset_root(""):
             raise ValueError(
                 f"dataset_root is passed but {self.dataset_root} does not exist."
             )
@@ -604,14 +602,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         frame_data.image_size_hw = image_size_hw  # original image size
         # image size after crop/resize
         frame_data.effective_image_size_hw = image_size_hw
+        image_path = None
+        dataset_root = self.dataset_root
+        if frame_annotation.image.path is not None and dataset_root is not None:
+            image_path = os.path.join(dataset_root, frame_annotation.image.path)
+        frame_data.image_path = image_path
 
         if load_blobs and self.load_images:
-            (
-                frame_data.image_rgb,
-                frame_data.image_path,
-            ) = self._load_images(frame_annotation, frame_data.fg_probability)
+            if image_path is None:
+                raise ValueError("Image path is required to load images.")
+
+            image_np = load_image(self._local_path(image_path))
+            frame_data.image_rgb = self._postprocess_image(
+                image_np, frame_annotation.image.size, frame_data.fg_probability
+            )
 
-        if load_blobs and self.load_depths and frame_annotation.depth is not None:
+        if (
+            load_blobs
+            and self.load_depths
+            and frame_annotation.depth is not None
+            and frame_annotation.depth.path is not None
+        ):
             (
                 frame_data.depth_map,
                 frame_data.depth_path,
@@ -652,25 +663,22 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
 
         return fg_probability, full_path
 
-    def _load_images(
+    def _postprocess_image(
         self,
-        entry: types.FrameAnnotation,
+        image_np: np.ndarray,
+        image_size: Tuple[int, int],
         fg_probability: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, str]:
-        assert self.dataset_root is not None and entry.image is not None
-        path = os.path.join(self.dataset_root, entry.image.path)
-        image_rgb = load_image(self._local_path(path))
+    ) -> torch.Tensor:
+        image_rgb = safe_as_tensor(image_np, torch.float)
 
-        if image_rgb.shape[-2:] != entry.image.size:
-            raise ValueError(
-                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
-            )
+        if image_rgb.shape[-2:] != image_size:
+            raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")
 
         if self.mask_images:
             assert fg_probability is not None
             image_rgb *= fg_probability
 
-        return image_rgb, path
+        return image_rgb
 
     def _load_mask_depth(
         self,
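Together with the loading hunk above, this separates image decoding (`load_image`) from tensor post-processing (`_postprocess_image`), which is the extensibility point the summary mentions: a subclass can now override just the post-processing step. A minimal sketch under stated assumptions; `MyFrameDataBuilder` and the exact import paths are illustrative, not part of the diff:

```
from typing import Optional, Tuple

import numpy as np
import torch

# Assumed module paths; adjust to wherever FrameDataBuilder and registry live.
from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder
from pytorch3d.implicitron.tools.config import registry


@registry.register
class MyFrameDataBuilder(FrameDataBuilder):  # hypothetical subclass
    def _postprocess_image(
        self,
        image_np: np.ndarray,
        image_size: Tuple[int, int],
        fg_probability: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Reuse the stock size check and foreground masking...
        image_rgb = super()._postprocess_image(image_np, image_size, fg_probability)
        # ...then apply a custom transform, e.g. clamp to the valid range.
        return image_rgb.clamp(0.0, 1.0)
```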
@@ -678,18 +686,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_probability: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, str, torch.Tensor]:
         entry_depth = entry.depth
-        assert self.dataset_root is not None and entry_depth is not None
-        path = os.path.join(self.dataset_root, entry_depth.path)
+        dataset_root = self.dataset_root
+        assert dataset_root is not None
+        assert entry_depth is not None and entry_depth.path is not None
+        path = os.path.join(dataset_root, entry_depth.path)
         depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
 
         if self.mask_depths:
             assert fg_probability is not None
             depth_map *= fg_probability
 
-        if self.load_depth_masks:
-            assert entry_depth.mask_path is not None
-            # pyre-ignore
-            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
+        mask_path = entry_depth.mask_path
+        if self.load_depth_masks and mask_path is not None:
+            mask_path = os.path.join(dataset_root, mask_path)
             depth_mask = load_depth_mask(self._local_path(mask_path))
         else:
             depth_mask = torch.ones_like(depth_map)
@@ -745,6 +754,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
             return path
         return self.path_manager.get_local_path(path)
 
+    def _exists_in_dataset_root(self, relpath) -> bool:
+        if not self.dataset_root:
+            return False
+
+        full_path = os.path.join(self.dataset_root, relpath)
+        if self.path_manager is None:
+            return os.path.exists(full_path)
+        else:
+            return self.path_manager.exists(full_path)
+
 
 @registry.register
 class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
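The new `_exists_in_dataset_root` helper centralizes the PathManager-aware existence check; with an empty relpath, as in the validation hunk above, it reduces to checking `dataset_root` itself. A standalone sketch of the same pattern (`exists_under_root` is an illustrative name, not from the diff):

```
import os
from typing import Optional


def exists_under_root(
    dataset_root: Optional[str], relpath: str, path_manager=None
) -> bool:
    # No dataset root configured: nothing can exist under it.
    if not dataset_root:
        return False
    full_path = os.path.join(dataset_root, relpath)
    # Fall back to the local filesystem when no PathManager is configured.
    if path_manager is None:
        return os.path.exists(full_path)
    return path_manager.exists(full_path)


# os.path.join(root, "") yields the root with a trailing separator, so an
# empty relpath effectively checks the dataset root directory itself.
assert exists_under_root(None, "anything") is False
```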
@@ -210,6 +210,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             seq, frame = self._index.index[frame_idx]
         else:
             seq, frame, *rest = frame_idx
+            if isinstance(frame, torch.LongTensor):
+                frame = frame.item()
+
             if (seq, frame) not in self._index.index:
                 raise IndexError(
                     f"Sequence-frame index {frame_idx} not found; was it filtered out?"
@@ -225,19 +225,23 @@ def resize_image(
     return imre_, minscale, mask
 
 
+def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
+    im = np.atleast_3d(image).transpose((2, 0, 1))
+    return im.astype(np.float32) / 255.0
+
+
 def load_image(path: str) -> np.ndarray:
     with Image.open(path) as pil_im:
         im = np.array(pil_im.convert("RGB"))
-        im = im.transpose((2, 0, 1))
-        im = im.astype(np.float32) / 255.0
-        return im
+
+    return transpose_normalize_image(im)
 
 
 def load_mask(path: str) -> np.ndarray:
     with Image.open(path) as pil_im:
         mask = np.array(pil_im)
-        mask = mask.astype(np.float32) / 255.0
-        return mask[None]  # fake feature channel
+
+    return transpose_normalize_image(mask)
 
 
 def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
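The factored-out `transpose_normalize_image` reproduces both old behaviors: for a 2D mask, `np.atleast_3d` appends a trailing channel before the transpose, which lands the channel first, exactly like the old `mask[None]`. A quick self-contained shape check (illustrative):

```
import numpy as np


def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
    im = np.atleast_3d(image).transpose((2, 0, 1))
    return im.astype(np.float32) / 255.0


rgb = np.zeros((4, 5, 3), dtype=np.uint8)   # (H, W, C) image
mask = np.zeros((4, 5), dtype=np.uint8)     # (H, W) mask

assert transpose_normalize_image(rgb).shape == (3, 4, 5)   # channel-first
assert transpose_normalize_image(mask).shape == (1, 4, 5)  # same as mask[None]
```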