Refactor: make FrameDataBuilder more extensible.

Summary:
This is mostly a refactoring diff to reduce friction in extending the frame data.

Slight functional change: the dataset's `__getitem__` now accepts `(seq_name, frame_number_as_singleton_tensor)` as a non-advertised feature. Without it, this code crashes:
```python
item = dataset[0]
dataset[item.sequence_name, item.frame_number]
```
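
After this change the lookup succeeds: `SqlIndexDataset` converts a singleton `LongTensor` frame number with `.item()` before the index lookup (see the sql_dataset hunk below). A minimal usage sketch, assuming `dataset` is a `SqlIndexDataset`:

```python
item = dataset[0]
# item.frame_number is a singleton LongTensor; the dataset now converts it
# with .item() before the (sequence_name, frame_number) index lookup.
same_item = dataset[item.sequence_name, item.frame_number]
```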

Reviewed By: bottler

Differential Revision: D45780175

fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
Roman Shapovalov 2023-05-17 10:38:34 -07:00 committed by Facebook GitHub Bot
parent d08fe6d45a
commit b0462598ac
5 changed files with 102 additions and 41 deletions


```diff
@@ -132,6 +132,11 @@ class TestExperiment(unittest.TestCase):
         # Check that the default config values, defined by Experiment and its
         # members, is what we expect it to be.
         cfg = OmegaConf.structured(experiment.Experiment)
+        # the following removes the possible effect of env variables
+        ds_arg = cfg.data_source_ImplicitronDataSource_args
+        ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
+        ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
+        cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097
         yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
         if DEBUG:
             (DATA_DIR / "experiment.yaml").write_text(yaml)
```
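
For context, the map providers read their default `dataset_root` from the environment, so the structured-config snapshot would vary between machines unless the test pins these fields. A minimal sketch of the pattern; the variable name `CO3D_DATASET_ROOT` is illustrative, not quoted from this diff:

```python
import os
from dataclasses import dataclass

# Default read once from the environment at import time (illustrative name);
# OmegaConf.structured() would then bake the machine-specific value into cfg.
_DATASET_ROOT: str = os.environ.get("CO3D_DATASET_ROOT", "")


@dataclass
class DatasetMapProviderArgs:
    dataset_root: str = _DATASET_ROOT
```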


```diff
@@ -203,7 +203,10 @@ class FrameData(Mapping[str, Any]):
             when no image has been loaded)
         """
         if self.bbox_xywh is None:
-            raise ValueError("Attempted cropping by metadata with empty bounding box")
+            raise ValueError(
+                "Attempted cropping by metadata with empty bounding box. Consider either"
+                " to remove_empty_masks or turn off box_crop in the dataset config."
+            )

         if not self._uncropped:
             raise ValueError(
@@ -528,12 +531,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
                 "Make sure it is set in either FrameDataBuilder or Dataset params."
             )

-        if self.path_manager is None:
-            dataset_root_exists = os.path.isdir(self.dataset_root)  # pyre-ignore
-        else:
-            dataset_root_exists = self.path_manager.isdir(self.dataset_root)
-
-        if load_any_blob and not dataset_root_exists:
+        if load_any_blob and not self._exists_in_dataset_root(""):
             raise ValueError(
                 f"dataset_root is passed but {self.dataset_root} does not exist."
             )
@@ -604,14 +602,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         frame_data.image_size_hw = image_size_hw  # original image size
         # image size after crop/resize
         frame_data.effective_image_size_hw = image_size_hw

+        image_path = None
+        dataset_root = self.dataset_root
+        if frame_annotation.image.path is not None and dataset_root is not None:
+            image_path = os.path.join(dataset_root, frame_annotation.image.path)
+            frame_data.image_path = image_path
+
         if load_blobs and self.load_images:
-            (
-                frame_data.image_rgb,
-                frame_data.image_path,
-            ) = self._load_images(frame_annotation, frame_data.fg_probability)
+            if image_path is None:
+                raise ValueError("Image path is required to load images.")
+
+            image_np = load_image(self._local_path(image_path))
+            frame_data.image_rgb = self._postprocess_image(
+                image_np, frame_annotation.image.size, frame_data.fg_probability
+            )

-        if load_blobs and self.load_depths and frame_annotation.depth is not None:
+        if (
+            load_blobs
+            and self.load_depths
+            and frame_annotation.depth is not None
+            and frame_annotation.depth.path is not None
+        ):
             (
                 frame_data.depth_map,
                 frame_data.depth_path,
@@ -652,25 +663,22 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         return fg_probability, full_path

-    def _load_images(
+    def _postprocess_image(
         self,
-        entry: types.FrameAnnotation,
+        image_np: np.ndarray,
+        image_size: Tuple[int, int],
         fg_probability: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, str]:
-        assert self.dataset_root is not None and entry.image is not None
-        path = os.path.join(self.dataset_root, entry.image.path)
-        image_rgb = load_image(self._local_path(path))
+    ) -> torch.Tensor:
+        image_rgb = safe_as_tensor(image_np, torch.float)

-        if image_rgb.shape[-2:] != entry.image.size:
-            raise ValueError(
-                f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
-            )
+        if image_rgb.shape[-2:] != image_size:
+            raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")

         if self.mask_images:
             assert fg_probability is not None
             image_rgb *= fg_probability

-        return image_rgb, path
+        return image_rgb

     def _load_mask_depth(
         self,
@@ -678,18 +686,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
         fg_probability: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, str, torch.Tensor]:
         entry_depth = entry.depth
-        assert self.dataset_root is not None and entry_depth is not None
-        path = os.path.join(self.dataset_root, entry_depth.path)
+        dataset_root = self.dataset_root
+        assert dataset_root is not None
+        assert entry_depth is not None and entry_depth.path is not None
+        path = os.path.join(dataset_root, entry_depth.path)
         depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)

         if self.mask_depths:
             assert fg_probability is not None
             depth_map *= fg_probability

-        if self.load_depth_masks:
-            assert entry_depth.mask_path is not None
-            # pyre-ignore
-            mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
+        mask_path = entry_depth.mask_path
+        if self.load_depth_masks and mask_path is not None:
+            mask_path = os.path.join(dataset_root, mask_path)
             depth_mask = load_depth_mask(self._local_path(mask_path))
         else:
             depth_mask = torch.ones_like(depth_map)
@@ -745,6 +754,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
             return path
         return self.path_manager.get_local_path(path)

+    def _exists_in_dataset_root(self, relpath) -> bool:
+        if not self.dataset_root:
+            return False
+
+        full_path = os.path.join(self.dataset_root, relpath)
+        if self.path_manager is None:
+            return os.path.exists(full_path)
+        else:
+            return self.path_manager.exists(full_path)
+

 @registry.register
 class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
```
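
Taken together, these hunks split image loading into path resolution (in the build method), blob loading, and `_postprocess_image`, with `_exists_in_dataset_root` absorbing the `path_manager` branching. A hypothetical sketch of the kind of extension this enables; the subclass and its behavior below are illustrative, not part of this diff:

```python
from typing import Optional, Tuple

import numpy as np
import torch

from pytorch3d.implicitron.dataset.frame_data import FrameDataBuilder


# Hypothetical subclass: reuse the generic path handling and blob loading,
# override only the post-processing stage that this diff factored out.
# (For config-driven use it would typically also be registered with the
# Implicitron registry, like FrameDataBuilder itself.)
class GrayscaleFrameDataBuilder(FrameDataBuilder):
    def _postprocess_image(
        self,
        image_np: np.ndarray,
        image_size: Tuple[int, int],
        fg_probability: Optional[torch.Tensor],
    ) -> torch.Tensor:
        image_rgb = super()._postprocess_image(image_np, image_size, fg_probability)
        # Average the channels to fake a grayscale modality (illustrative only).
        return image_rgb.mean(dim=0, keepdim=True).expand(image_rgb.shape)
```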


```diff
@@ -210,6 +210,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase):  # pyre-ignore
             seq, frame = self._index.index[frame_idx]
         else:
             seq, frame, *rest = frame_idx
+            if isinstance(frame, torch.LongTensor):
+                frame = frame.item()
+
             if (seq, frame) not in self._index.index:
                 raise IndexError(
                     f"Sequence-frame index {frame_idx} not found; was it filtered out?"
```


```diff
@@ -225,19 +225,23 @@ def resize_image(
     return imre_, minscale, mask


+def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
+    im = np.atleast_3d(image).transpose((2, 0, 1))
+    return im.astype(np.float32) / 255.0
+
+
 def load_image(path: str) -> np.ndarray:
     with Image.open(path) as pil_im:
         im = np.array(pil_im.convert("RGB"))
-    im = im.transpose((2, 0, 1))
-    im = im.astype(np.float32) / 255.0

-    return im
+    return transpose_normalize_image(im)


 def load_mask(path: str) -> np.ndarray:
     with Image.open(path) as pil_im:
         mask = np.array(pil_im)
-    mask = mask.astype(np.float32) / 255.0

-    return mask[None]  # fake feature channel
+    return transpose_normalize_image(mask)


 def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
```
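
The new helper gives both loaders one shape contract: channel-first float32 in [0, 1], with 2D inputs gaining a fake channel axis via `np.atleast_3d`. A quick sanity sketch:

```python
import numpy as np
from pytorch3d.implicitron.dataset.utils import transpose_normalize_image

rgb = np.zeros((480, 640, 3), dtype=np.uint8)  # HxWxC uint8 image
assert transpose_normalize_image(rgb).shape == (3, 480, 640)  # CxHxW float32

mask = np.zeros((480, 640), dtype=np.uint8)  # 2D mask
assert transpose_normalize_image(mask).shape == (1, 480, 640)  # fake channel
```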


```diff
@@ -25,6 +25,7 @@ from pytorch3d.implicitron.dataset.utils import (
     load_image,
     load_mask,
     safe_as_tensor,
+    transpose_normalize_image,
 )
 from pytorch3d.implicitron.tools.config import get_default_args
 from pytorch3d.renderer.cameras import PerspectiveCameras
@@ -123,14 +124,15 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
         # assert bboxes shape
         self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))

-        (
-            self.frame_data.image_rgb,
-            self.frame_data.image_path,
-        ) = self.frame_data_builder._load_images(
-            self.frame_annotation, self.frame_data.fg_probability
+        image_path = os.path.join(
+            self.frame_data_builder.dataset_root, self.frame_annotation.image.path
         )
-        self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
-        self.assertIsNotNone(self.frame_data.image_path)
+        image_np = load_image(self.frame_data_builder._local_path(image_path))
+        self.assertIsInstance(image_np, np.ndarray)
+        self.frame_data.image_rgb = self.frame_data_builder._postprocess_image(
+            image_np, self.frame_annotation.image.size, self.frame_data.fg_probability
+        )
+        self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)

         (
             self.frame_data.depth_map,
@@ -184,6 +186,34 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
         )
         self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)

+    def test_transpose_normalize_image(self):
+        def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray:
+            im = image * 255.0
+            return im.transpose((1, 2, 0)).astype(np.uint8)
+
+        # Test 2D input
+        input_image = np.array(
+            [[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8
+        )
+        expected_input = inverse_transpose_normalize_image(
+            transpose_normalize_image(input_image)
+        )
+        self.assertClose(input_image[..., None], expected_input)
+
+        # Test 3D input
+        input_image = np.array(
+            [
+                [[10, 20, 30], [40, 50, 60], [70, 80, 90]],
+                [[100, 110, 120], [130, 140, 150], [160, 170, 180]],
+                [[190, 200, 210], [220, 230, 240], [250, 255, 255]],
+            ],
+            dtype=np.uint8,
+        )
+        expected_input = inverse_transpose_normalize_image(
+            transpose_normalize_image(input_image)
+        )
+        self.assertClose(input_image, expected_input)
+
     def test_load_image(self):
         path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
         local_path = self.path_manager.get_local_path(path)
```