From 1fb268dea682d4b24df0fe3644868ec499cca8a7 Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 10 Jun 2022 12:22:46 -0700 Subject: [PATCH] allow get_default_args(JsonIndexDataset) Summary: Changes to JsonIndexDataset to make it fit with OmegaConf.structured. Also match some default values to what the provider defaults to. Reviewed By: davnov134 Differential Revision: D36666704 fbshipit-source-id: 65b059a1dbaa240ce85c3e8762b7c3db3b5a6e75 --- pytorch3d/implicitron/dataset/dataset_base.py | 32 +++++++++++-- .../implicitron/dataset/json_index_dataset.py | 47 +++++++++++++------ pytorch3d/implicitron/dataset/visualize.py | 1 + tests/implicitron/test_data_source.py | 28 +++++++++++ tests/implicitron/test_evaluation.py | 1 + 5 files changed, 90 insertions(+), 19 deletions(-) diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index 83859f6b..11a0cbae 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -8,7 +8,6 @@ from collections import defaultdict from dataclasses import dataclass, field, fields from typing import ( Any, - Dict, Iterable, Iterator, List, @@ -182,8 +181,28 @@ class FrameData(Mapping[str, Any]): return torch.utils.data._utils.collate.default_collate(batch) +class _GenericWorkaround: + """ + OmegaConf.structured has a weirdness when you try to apply + it to a dataclass whose first base class is a Generic which is not + Dict. The issue is with a function called get_dict_key_value_types + in omegaconf/_utils.py. + For example this fails: + + @dataclass(eq=False) + class D(torch.utils.data.Dataset[int]): + a: int = 3 + + OmegaConf.structured(D) + + We avoid the problem by adding this class as an extra base class. + """ + + pass + + @dataclass(eq=False) -class DatasetBase(torch.utils.data.Dataset[FrameData]): +class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): """ Base class to describe a dataset to be used with Implicitron. @@ -195,10 +214,11 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]): which will describe one frame in one sequence. """ - # Maps sequence name to the sequence's global frame indices. + # _seq_to_idx is a member which implementations can define. + # It maps sequence name to the sequence's global frame indices. # It is used for the default implementations of some functions in this class. - # Implementations which override them are free to ignore this member. - _seq_to_idx: Dict[str, List[int]] = field(init=False) + # Implementations which override them are free to ignore it. + # _seq_to_idx: Dict[str, List[int]] = field(init=False) def __len__(self) -> int: raise NotImplementedError @@ -232,6 +252,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]): def sequence_names(self) -> Iterable[str]: """Returns an iterator over sequence names in the dataset.""" + # pyre-ignore[16] return self._seq_to_idx.keys() def sequence_frames_in_order( @@ -250,6 +271,7 @@ class DatasetBase(torch.utils.data.Dataset[FrameData]): `dataset_idx` is the index within the dataset. `None` timestamps are replaced with 0s. """ + # pyre-ignore[16] seq_frame_indices = self._seq_to_idx[seq_name] nos_timestamps = self.get_frame_numbers_and_timestamps(seq_frame_indices) diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index bd8d1711..1755ba07 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -13,12 +13,12 @@ import os import random import warnings from collections import defaultdict -from dataclasses import dataclass, field +from dataclasses import dataclass from itertools import islice from pathlib import Path from typing import ( + Any, ClassVar, - Dict, List, Optional, Sequence, @@ -30,7 +30,6 @@ from typing import ( import numpy as np import torch -from iopath.common.file_io import PathManager from PIL import Image from pytorch3d.io import IO from pytorch3d.renderer.cameras import PerspectiveCameras @@ -116,7 +115,7 @@ class JsonIndexDataset(DatasetBase): Type[types.FrameAnnotation] ] = types.FrameAnnotation - path_manager: Optional[PathManager] = None + path_manager: Any = None frame_annotations_file: str = "" sequence_annotations_file: str = "" subset_lists_file: str = "" @@ -135,18 +134,18 @@ class JsonIndexDataset(DatasetBase): max_points: int = 0 mask_images: bool = False mask_depths: bool = False - image_height: Optional[int] = 256 - image_width: Optional[int] = 256 - box_crop: bool = False + image_height: Optional[int] = 800 + image_width: Optional[int] = 800 + box_crop: bool = True box_crop_mask_thr: float = 0.4 - box_crop_context: float = 1.0 - remove_empty_masks: bool = False + box_crop_context: float = 0.3 + remove_empty_masks: bool = True n_frames_per_sequence: int = -1 seed: int = 0 sort_frames: bool = False - eval_batches: Optional[List[List[int]]] = None - frame_annots: List[FrameAnnotsEntry] = field(init=False) - seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) + eval_batches: Any = None + # frame_annots: List[FrameAnnotsEntry] = field(init=False) + # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) def __post_init__(self) -> None: # pyre-fixme[16]: `JsonIndexDataset` has no attribute `subset_to_image_path`. @@ -172,9 +171,11 @@ class JsonIndexDataset(DatasetBase): # TODO: check the frame numbers are unique _dataset_seq_frame_n_index = { seq: { + # pyre-ignore[16] self.frame_annots[idx]["frame_annotation"].frame_number: idx for idx in seq_idx } + # pyre-ignore[16] for seq, seq_idx in self._seq_to_idx.items() } @@ -184,6 +185,7 @@ class JsonIndexDataset(DatasetBase): # Check that the loaded frame path is consistent # with the one stored in self.frame_annots. assert os.path.normpath( + # pyre-ignore[16] self.frame_annots[idx]["frame_annotation"].image.path ) == os.path.normpath( path @@ -194,19 +196,23 @@ class JsonIndexDataset(DatasetBase): return batches_idx def __str__(self) -> str: + # pyre-ignore[16] return f"JsonIndexDataset #frames={len(self.frame_annots)}" def __len__(self) -> int: + # pyre-ignore[16] return len(self.frame_annots) def _get_frame_type(self, entry: FrameAnnotsEntry) -> Optional[str]: return entry["subset"] def __getitem__(self, index) -> FrameData: + # pyre-ignore[16] if index >= len(self.frame_annots): raise IndexError(f"index {index} out of range {len(self.frame_annots)}") entry = self.frame_annots[index]["frame_annotation"] + # pyre-ignore[16] point_cloud = self.seq_annots[entry.sequence_name].point_cloud frame_data = FrameData( frame_number=_safe_as_tensor(entry.frame_number, torch.long), @@ -441,6 +447,7 @@ class JsonIndexDataset(DatasetBase): ) if not frame_annots_list: raise ValueError("Empty dataset!") + # pyre-ignore[16] self.frame_annots = [ FrameAnnotsEntry(frame_annotation=a, subset=None) for a in frame_annots_list ] @@ -452,6 +459,7 @@ class JsonIndexDataset(DatasetBase): seq_annots = types.load_dataclass(zipfile, List[types.SequenceAnnotation]) if not seq_annots: raise ValueError("Empty sequences file!") + # pyre-ignore[16] self.seq_annots = {entry.sequence_name: entry for entry in seq_annots} def _load_subset_lists(self) -> None: @@ -467,7 +475,7 @@ class JsonIndexDataset(DatasetBase): for subset, frames in subset_to_seq_frame.items() for _, _, path in frames } - + # pyre-ignore[16] for frame in self.frame_annots: frame["subset"] = frame_path_to_subset.get( frame["frame_annotation"].image.path, None @@ -480,6 +488,7 @@ class JsonIndexDataset(DatasetBase): def _sort_frames(self) -> None: # Sort frames to have them grouped by sequence, ordered by timestamp + # pyre-ignore[16] self.frame_annots = sorted( self.frame_annots, key=lambda f: ( @@ -491,6 +500,7 @@ class JsonIndexDataset(DatasetBase): def _filter_db(self) -> None: if self.remove_empty_masks: logger.info("Removing images with empty masks.") + # pyre-ignore[16] old_len = len(self.frame_annots) msg = "remove_empty_masks needs every MaskAnnotation.mass to be set." @@ -531,6 +541,7 @@ class JsonIndexDataset(DatasetBase): if len(self.limit_category_to) > 0: logger.info(f"Limiting dataset to categories: {self.limit_category_to}") + # pyre-ignore[16] self.seq_annots = { name: entry for name, entry in self.seq_annots.items() @@ -568,6 +579,7 @@ class JsonIndexDataset(DatasetBase): if self.n_frames_per_sequence > 0: logger.info(f"Taking max {self.n_frames_per_sequence} per sequence.") keep_idx = [] + # pyre-ignore[16] for seq, seq_indices in self._seq_to_idx.items(): # infer the seed from the sequence name, this is reproducible # and makes the selection differ for different sequences @@ -597,14 +609,20 @@ class JsonIndexDataset(DatasetBase): self._invalidate_seq_to_idx() if filter_seq_annots: + # pyre-ignore[16] self.seq_annots = { - k: v for k, v in self.seq_annots.items() if k in self._seq_to_idx + k: v + for k, v in self.seq_annots.items() + # pyre-ignore[16] + if k in self._seq_to_idx } def _invalidate_seq_to_idx(self) -> None: seq_to_idx = defaultdict(list) + # pyre-ignore[16] for idx, entry in enumerate(self.frame_annots): seq_to_idx[entry["frame_annotation"].sequence_name].append(idx) + # pyre-ignore[16] self._seq_to_idx = seq_to_idx def _resize_image( @@ -644,6 +662,7 @@ class JsonIndexDataset(DatasetBase): ) -> List[Tuple[int, float]]: out: List[Tuple[int, float]] = [] for idx in idxs: + # pyre-ignore[16] frame_annotation = self.frame_annots[idx]["frame_annotation"] out.append( (frame_annotation.frame_number, frame_annotation.frame_timestamp) diff --git a/pytorch3d/implicitron/dataset/visualize.py b/pytorch3d/implicitron/dataset/visualize.py index 331ee892..8a4be469 100644 --- a/pytorch3d/implicitron/dataset/visualize.py +++ b/pytorch3d/implicitron/dataset/visualize.py @@ -44,6 +44,7 @@ def get_implicitron_sequence_pointcloud( sequence_entries = [ ei for ei in sequence_entries + # pyre-ignore[16] if dataset.frame_annots[ei]["frame_annotation"].sequence_name == sequence_name ] diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py index 4d664495..d61957b0 100644 --- a/tests/implicitron/test_data_source.py +++ b/tests/implicitron/test_data_source.py @@ -9,6 +9,7 @@ import unittest from omegaconf import OmegaConf from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource +from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset from pytorch3d.implicitron.tools.config import get_default_args from tests.common_testing import get_tests_dir @@ -20,6 +21,33 @@ class TestDataSource(unittest.TestCase): def setUp(self): self.maxDiff = None + def _test_omegaconf_generic_failure(self): + # OmegaConf possible bug - this is why we need _GenericWorkaround + from dataclasses import dataclass + + import torch + + @dataclass + class D(torch.utils.data.Dataset[int]): + a: int = 3 + + OmegaConf.structured(D) + + def _test_omegaconf_ListList(self): + # Demo that OmegaConf doesn't support nested lists + from dataclasses import dataclass + from typing import Sequence + + @dataclass + class A: + a: Sequence[Sequence[int]] = ((32,),) + + OmegaConf.structured(A) + + def test_JsonIndexDataset_args(self): + # test that JsonIndexDataset works with get_default_args + get_default_args(JsonIndexDataset) + def test_one(self): with unittest.mock.patch.dict(os.environ, {"CO3D_DATASET_ROOT": ""}): cfg = get_default_args(ImplicitronDataSource) diff --git a/tests/implicitron/test_evaluation.py b/tests/implicitron/test_evaluation.py index f8a7ddfd..ed43dca6 100644 --- a/tests/implicitron/test_evaluation.py +++ b/tests/implicitron/test_evaluation.py @@ -51,6 +51,7 @@ class TestEvaluation(unittest.TestCase): image_height=self.image_size, image_width=self.image_size, box_crop=True, + remove_empty_masks=False, path_manager=path_manager, ) self.bg_color = (0.0, 0.0, 0.0)