From 597e0259dc43bf4903e9c99f5d61410c1ad75b78 Mon Sep 17 00:00:00 2001 From: Darijan Gudelj Date: Mon, 1 Aug 2022 10:23:34 -0700 Subject: [PATCH] Made eval_batches be set inside the __init__ Summary: Made eval_batches be set in call to `__init__` not after the construction as they were before Reviewed By: bottler Differential Revision: D38275943 fbshipit-source-id: 32737401d1ddd16c284e1851b7a91f8b041c406f --- .../implicitron/dataset/json_index_dataset.py | 18 ++++++++++++++++++ .../dataset/json_index_dataset_map_provider.py | 13 ++++++------- .../json_index_dataset_map_provider_v2.py | 10 ++++++++++ 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 2df567d0..115514da 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -112,6 +112,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): eval_batches: A list of batches that form the evaluation set; list of batch-sized lists of indices corresponding to __getitem__ of this class, thus it can be used directly as a batch sampler. + eval_batch_index: + ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] ) + A list of batches of frames described as (sequence_name, frame_idx) + that can form the evaluation set, `eval_batches` will be set from this. + """ frame_annotations_type: ClassVar[ @@ -147,6 +152,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): seed: int = 0 sort_frames: bool = False eval_batches: Any = None + eval_batch_index: Any = None # frame_annots: List[FrameAnnotsEntry] = field(init=False) # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False) @@ -159,8 +165,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): self._sort_frames() self._load_subset_lists() self._filter_db() # also computes sequence indices + self._extract_and_set_eval_batches() logger.info(str(self)) + def _extract_and_set_eval_batches(self): + """ + Sets eval_batches based on input eval_batch_index. + """ + if self.eval_batch_index is not None: + if self.eval_batches is not None: + raise ValueError( + "Cannot define both eval_batch_index and eval_batches." + ) + self.eval_batches = self.seq_frame_index_to_dataset_index() + def is_filtered(self): """ Returns `True` in case the dataset has been filtered and thus some frame annotations diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py index 06f70231..23da875f 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider.py @@ -57,6 +57,7 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "") _NEED_CONTROL: Tuple[str, ...] = ( "dataset_root", "eval_batches", + "eval_batch_index", "n_frames_per_sequence", "path_manager", "pick_sequence", @@ -212,6 +213,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] ] # overwrite the restrict_sequence_name restrict_sequence_name = [eval_sequence_name] + if len(restrict_sequence_name) > 0: + eval_batch_index = [ + b for b in eval_batch_index if b[0][0] in restrict_sequence_name + ] dataset_type: Type[JsonIndexDataset] = registry.get( JsonIndexDataset, self.dataset_class_type @@ -239,15 +244,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13] n_frames_per_sequence=-1, subsets=set_names_mapping["test"], pick_sequence=restrict_sequence_name, + eval_batch_index=eval_batch_index, **common_kwargs, ) - if len(restrict_sequence_name) > 0: - eval_batch_index = [ - b for b in eval_batch_index if b[0][0] in restrict_sequence_name - ] - test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index( - eval_batch_index - ) dataset_map = DatasetMap( train=train_dataset, val=val_dataset, test=test_dataset ) diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py index 6b96de65..cbd1cf33 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py @@ -11,6 +11,7 @@ import os import warnings from typing import Dict, List, Optional, Type +from omegaconf import DictConfig, open_dict from pytorch3d.implicitron.dataset.dataset_map_provider import ( DatasetMap, DatasetMapProviderBase, @@ -269,6 +270,15 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] train=train_dataset, val=val_dataset, test=test_dataset ) + @classmethod + def dataset_tweak_args(cls, type, args: DictConfig) -> None: + """ + Called by get_default_args(JsonIndexDatasetMapProviderV2) to + not expose certain fields of each dataset class. + """ + with open_dict(args): + del args["eval_batch_index"] + def create_dataset(self): # The dataset object is created inside `self.get_dataset_map` pass