mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Refactor: FrameDataBuilder is more extensible.
Summary: This is mostly a refactoring diff to reduce friction in extending the frame data. Slight functional changes: dataset getitem now accepts (seq_name, frame_number_as_singleton_tensor) as a non-advertised feature. Otherwise this code crashes: ``` item = dataset[0] dataset[item.sequence_name, item.frame_number] ``` Reviewed By: bottler Differential Revision: D45780175 fbshipit-source-id: 75b8e8d3dabed954a804310abdbd8ab44a8dea29
This commit is contained in:
parent
d08fe6d45a
commit
b0462598ac
@ -132,6 +132,11 @@ class TestExperiment(unittest.TestCase):
|
||||
# Check that the default config values, defined by Experiment and its
|
||||
# members, is what we expect it to be.
|
||||
cfg = OmegaConf.structured(experiment.Experiment)
|
||||
# the following removes the possible effect of env variables
|
||||
ds_arg = cfg.data_source_ImplicitronDataSource_args
|
||||
ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = ""
|
||||
ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = ""
|
||||
cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097
|
||||
yaml = OmegaConf.to_yaml(cfg, sort_keys=False)
|
||||
if DEBUG:
|
||||
(DATA_DIR / "experiment.yaml").write_text(yaml)
|
||||
|
@ -203,7 +203,10 @@ class FrameData(Mapping[str, Any]):
|
||||
when no image has been loaded)
|
||||
"""
|
||||
if self.bbox_xywh is None:
|
||||
raise ValueError("Attempted cropping by metadata with empty bounding box")
|
||||
raise ValueError(
|
||||
"Attempted cropping by metadata with empty bounding box. Consider either"
|
||||
" to remove_empty_masks or turn off box_crop in the dataset config."
|
||||
)
|
||||
|
||||
if not self._uncropped:
|
||||
raise ValueError(
|
||||
@ -528,12 +531,7 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
"Make sure it is set in either FrameDataBuilder or Dataset params."
|
||||
)
|
||||
|
||||
if self.path_manager is None:
|
||||
dataset_root_exists = os.path.isdir(self.dataset_root) # pyre-ignore
|
||||
else:
|
||||
dataset_root_exists = self.path_manager.isdir(self.dataset_root)
|
||||
|
||||
if load_any_blob and not dataset_root_exists:
|
||||
if load_any_blob and not self._exists_in_dataset_root(""):
|
||||
raise ValueError(
|
||||
f"dataset_root is passed but {self.dataset_root} does not exist."
|
||||
)
|
||||
@ -604,14 +602,27 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
frame_data.image_size_hw = image_size_hw # original image size
|
||||
# image size after crop/resize
|
||||
frame_data.effective_image_size_hw = image_size_hw
|
||||
image_path = None
|
||||
dataset_root = self.dataset_root
|
||||
if frame_annotation.image.path is not None and dataset_root is not None:
|
||||
image_path = os.path.join(dataset_root, frame_annotation.image.path)
|
||||
frame_data.image_path = image_path
|
||||
|
||||
if load_blobs and self.load_images:
|
||||
(
|
||||
frame_data.image_rgb,
|
||||
frame_data.image_path,
|
||||
) = self._load_images(frame_annotation, frame_data.fg_probability)
|
||||
if image_path is None:
|
||||
raise ValueError("Image path is required to load images.")
|
||||
|
||||
if load_blobs and self.load_depths and frame_annotation.depth is not None:
|
||||
image_np = load_image(self._local_path(image_path))
|
||||
frame_data.image_rgb = self._postprocess_image(
|
||||
image_np, frame_annotation.image.size, frame_data.fg_probability
|
||||
)
|
||||
|
||||
if (
|
||||
load_blobs
|
||||
and self.load_depths
|
||||
and frame_annotation.depth is not None
|
||||
and frame_annotation.depth.path is not None
|
||||
):
|
||||
(
|
||||
frame_data.depth_map,
|
||||
frame_data.depth_path,
|
||||
@ -652,25 +663,22 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
|
||||
return fg_probability, full_path
|
||||
|
||||
def _load_images(
|
||||
def _postprocess_image(
|
||||
self,
|
||||
entry: types.FrameAnnotation,
|
||||
image_np: np.ndarray,
|
||||
image_size: Tuple[int, int],
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str]:
|
||||
assert self.dataset_root is not None and entry.image is not None
|
||||
path = os.path.join(self.dataset_root, entry.image.path)
|
||||
image_rgb = load_image(self._local_path(path))
|
||||
) -> torch.Tensor:
|
||||
image_rgb = safe_as_tensor(image_np, torch.float)
|
||||
|
||||
if image_rgb.shape[-2:] != entry.image.size:
|
||||
raise ValueError(
|
||||
f"bad image size: {image_rgb.shape[-2:]} vs {entry.image.size}!"
|
||||
)
|
||||
if image_rgb.shape[-2:] != image_size:
|
||||
raise ValueError(f"bad image size: {image_rgb.shape[-2:]} vs {image_size}!")
|
||||
|
||||
if self.mask_images:
|
||||
assert fg_probability is not None
|
||||
image_rgb *= fg_probability
|
||||
|
||||
return image_rgb, path
|
||||
return image_rgb
|
||||
|
||||
def _load_mask_depth(
|
||||
self,
|
||||
@ -678,18 +686,19 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
fg_probability: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, str, torch.Tensor]:
|
||||
entry_depth = entry.depth
|
||||
assert self.dataset_root is not None and entry_depth is not None
|
||||
path = os.path.join(self.dataset_root, entry_depth.path)
|
||||
dataset_root = self.dataset_root
|
||||
assert dataset_root is not None
|
||||
assert entry_depth is not None and entry_depth.path is not None
|
||||
path = os.path.join(dataset_root, entry_depth.path)
|
||||
depth_map = load_depth(self._local_path(path), entry_depth.scale_adjustment)
|
||||
|
||||
if self.mask_depths:
|
||||
assert fg_probability is not None
|
||||
depth_map *= fg_probability
|
||||
|
||||
if self.load_depth_masks:
|
||||
assert entry_depth.mask_path is not None
|
||||
# pyre-ignore
|
||||
mask_path = os.path.join(self.dataset_root, entry_depth.mask_path)
|
||||
mask_path = entry_depth.mask_path
|
||||
if self.load_depth_masks and mask_path is not None:
|
||||
mask_path = os.path.join(dataset_root, mask_path)
|
||||
depth_mask = load_depth_mask(self._local_path(mask_path))
|
||||
else:
|
||||
depth_mask = torch.ones_like(depth_map)
|
||||
@ -745,6 +754,16 @@ class GenericFrameDataBuilder(FrameDataBuilderBase[FrameDataSubtype], ABC):
|
||||
return path
|
||||
return self.path_manager.get_local_path(path)
|
||||
|
||||
def _exists_in_dataset_root(self, relpath) -> bool:
|
||||
if not self.dataset_root:
|
||||
return False
|
||||
|
||||
full_path = os.path.join(self.dataset_root, relpath)
|
||||
if self.path_manager is None:
|
||||
return os.path.exists(full_path)
|
||||
else:
|
||||
return self.path_manager.exists(full_path)
|
||||
|
||||
|
||||
@registry.register
|
||||
class FrameDataBuilder(GenericWorkaround, GenericFrameDataBuilder[FrameData]):
|
||||
|
@ -210,6 +210,9 @@ class SqlIndexDataset(DatasetBase, ReplaceableBase): # pyre-ignore
|
||||
seq, frame = self._index.index[frame_idx]
|
||||
else:
|
||||
seq, frame, *rest = frame_idx
|
||||
if isinstance(frame, torch.LongTensor):
|
||||
frame = frame.item()
|
||||
|
||||
if (seq, frame) not in self._index.index:
|
||||
raise IndexError(
|
||||
f"Sequence-frame index {frame_idx} not found; was it filtered out?"
|
||||
|
@ -225,19 +225,23 @@ def resize_image(
|
||||
return imre_, minscale, mask
|
||||
|
||||
|
||||
def transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
im = np.atleast_3d(image).transpose((2, 0, 1))
|
||||
return im.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
def load_image(path: str) -> np.ndarray:
|
||||
with Image.open(path) as pil_im:
|
||||
im = np.array(pil_im.convert("RGB"))
|
||||
im = im.transpose((2, 0, 1))
|
||||
im = im.astype(np.float32) / 255.0
|
||||
return im
|
||||
|
||||
return transpose_normalize_image(im)
|
||||
|
||||
|
||||
def load_mask(path: str) -> np.ndarray:
|
||||
with Image.open(path) as pil_im:
|
||||
mask = np.array(pil_im)
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
return mask[None] # fake feature channel
|
||||
|
||||
return transpose_normalize_image(mask)
|
||||
|
||||
|
||||
def load_depth(path: str, scale_adjustment: float) -> np.ndarray:
|
||||
|
@ -25,6 +25,7 @@ from pytorch3d.implicitron.dataset.utils import (
|
||||
load_image,
|
||||
load_mask,
|
||||
safe_as_tensor,
|
||||
transpose_normalize_image,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
from pytorch3d.renderer.cameras import PerspectiveCameras
|
||||
@ -123,14 +124,15 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
|
||||
# assert bboxes shape
|
||||
self.assertEqual(self.frame_data.bbox_xywh.shape, torch.Size([4]))
|
||||
|
||||
(
|
||||
self.frame_data.image_rgb,
|
||||
self.frame_data.image_path,
|
||||
) = self.frame_data_builder._load_images(
|
||||
self.frame_annotation, self.frame_data.fg_probability
|
||||
image_path = os.path.join(
|
||||
self.frame_data_builder.dataset_root, self.frame_annotation.image.path
|
||||
)
|
||||
self.assertEqual(type(self.frame_data.image_rgb), np.ndarray)
|
||||
self.assertIsNotNone(self.frame_data.image_path)
|
||||
image_np = load_image(self.frame_data_builder._local_path(image_path))
|
||||
self.assertIsInstance(image_np, np.ndarray)
|
||||
self.frame_data.image_rgb = self.frame_data_builder._postprocess_image(
|
||||
image_np, self.frame_annotation.image.size, self.frame_data.fg_probability
|
||||
)
|
||||
self.assertIsInstance(self.frame_data.image_rgb, torch.Tensor)
|
||||
|
||||
(
|
||||
self.frame_data.depth_map,
|
||||
@ -184,6 +186,34 @@ class TestFrameDataBuilder(TestCaseMixin, unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(type(self.frame_data.camera), PerspectiveCameras)
|
||||
|
||||
def test_transpose_normalize_image(self):
|
||||
def inverse_transpose_normalize_image(image: np.ndarray) -> np.ndarray:
|
||||
im = image * 255.0
|
||||
return im.transpose((1, 2, 0)).astype(np.uint8)
|
||||
|
||||
# Test 2D input
|
||||
input_image = np.array(
|
||||
[[10, 20, 30], [40, 50, 60], [70, 80, 90]], dtype=np.uint8
|
||||
)
|
||||
expected_input = inverse_transpose_normalize_image(
|
||||
transpose_normalize_image(input_image)
|
||||
)
|
||||
self.assertClose(input_image[..., None], expected_input)
|
||||
|
||||
# Test 3D input
|
||||
input_image = np.array(
|
||||
[
|
||||
[[10, 20, 30], [40, 50, 60], [70, 80, 90]],
|
||||
[[100, 110, 120], [130, 140, 150], [160, 170, 180]],
|
||||
[[190, 200, 210], [220, 230, 240], [250, 255, 255]],
|
||||
],
|
||||
dtype=np.uint8,
|
||||
)
|
||||
expected_input = inverse_transpose_normalize_image(
|
||||
transpose_normalize_image(input_image)
|
||||
)
|
||||
self.assertClose(input_image, expected_input)
|
||||
|
||||
def test_load_image(self):
|
||||
path = os.path.join(self.dataset_root, self.frame_annotation.image.path)
|
||||
local_path = self.path_manager.get_local_path(path)
|
||||
|
Loading…
x
Reference in New Issue
Block a user