Set eval_batches inside __init__

Summary: `eval_batches` is now set in the call to `__init__`, rather than after construction as before.

Reviewed By: bottler

Differential Revision: D38275943

fbshipit-source-id: 32737401d1ddd16c284e1851b7a91f8b041c406f
Darijan Gudelj 2022-08-01 10:23:34 -07:00 committed by Facebook GitHub Bot
parent 80fc0ee0b6
commit 597e0259dc
3 changed files with 34 additions and 7 deletions
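
For context, here is a minimal sketch (not part of the diff) of the usage change this commit enables. The annotation file paths and sequence names are placeholders and building the dataset requires real CO3D-style annotation files, so treat it as an illustration rather than a runnable recipe; only the commented-out post-construction pattern at the end is taken from the old code further down.

# Hypothetical example of the new construction path; paths and sequence names are made up.
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import expand_args_fields

# Implicitron configurables need their dataclass fields expanded before direct instantiation.
expand_args_fields(JsonIndexDataset)

# One evaluation batch per inner list; entries are (sequence_name, frame_number).
eval_batch_index = [
    [("teddybear_001", 0), ("teddybear_001", 10)],
    [("teddybear_002", 5), ("teddybear_002", 15)],
]

# After this commit: hand eval_batch_index to __init__ and eval_batches is derived
# during construction by _extract_and_set_eval_batches().
dataset = JsonIndexDataset(
    frame_annotations_file="frame_annotations.jgz",
    sequence_annotations_file="sequence_annotations.jgz",
    dataset_root="/path/to/co3d",
    eval_batch_index=eval_batch_index,
)

# Before this commit the caller converted and assigned after construction:
# dataset.eval_batches = dataset.seq_frame_index_to_dataset_index(eval_batch_index)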


@@ -112,6 +112,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        eval_batches: A list of batches that form the evaluation set;
            list of batch-sized lists of indices corresponding to __getitem__
            of this class, thus it can be used directly as a batch sampler.
        eval_batch_index:
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]] )
            A list of batches of frames described as (sequence_name, frame_idx)
            that can form the evaluation set; `eval_batches` will be set from this.
    """

    frame_annotations_type: ClassVar[
@@ -147,6 +152,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
    seed: int = 0
    sort_frames: bool = False
    eval_batches: Any = None
    eval_batch_index: Any = None
    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
@@ -159,8 +165,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
        self._sort_frames()
        self._load_subset_lists()
        self._filter_db()  # also computes sequence indices
        self._extract_and_set_eval_batches()
        logger.info(str(self))

    def _extract_and_set_eval_batches(self):
        """
        Sets eval_batches based on input eval_batch_index.
        """
        if self.eval_batch_index is not None:
            if self.eval_batches is not None:
                raise ValueError(
                    "Cannot define both eval_batch_index and eval_batches."
                )
            self.eval_batches = self.seq_frame_index_to_dataset_index(
                self.eval_batch_index
            )

    def is_filtered(self):
        """
        Returns `True` in case the dataset has been filtered and thus some frame annotations
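
To make the documented shape of the new field concrete, a small standalone illustration follows; the sequence names and image path are invented. Passing such an index together with an explicit eval_batches triggers the ValueError shown above.

from typing import List, Tuple, Union

# Shape accepted by eval_batch_index per the docstring above: each inner list is one
# evaluation batch, and each entry names a frame either as (sequence_name, frame_number)
# or as (sequence_name, frame_number, image_path).
EvalBatchIndex = List[List[Union[Tuple[str, int, str], Tuple[str, int]]]]

example_index: EvalBatchIndex = [
    [("hydrant_042", 0), ("hydrant_042", 7), ("hydrant_042", 21)],
    [("plant_013", 3, "plant_013/images/frame000003.jpg"), ("plant_013", 11)],
]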


@@ -57,6 +57,7 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
_NEED_CONTROL: Tuple[str, ...] = (
    "dataset_root",
    "eval_batches",
    "eval_batch_index",
    "n_frames_per_sequence",
    "path_manager",
    "pick_sequence",
@@ -212,6 +213,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
            ]
            # overwrite the restrict_sequence_name
            restrict_sequence_name = [eval_sequence_name]

        if len(restrict_sequence_name) > 0:
            eval_batch_index = [
                b for b in eval_batch_index if b[0][0] in restrict_sequence_name
            ]

        dataset_type: Type[JsonIndexDataset] = registry.get(
            JsonIndexDataset, self.dataset_class_type
@@ -239,15 +244,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
            n_frames_per_sequence=-1,
            subsets=set_names_mapping["test"],
            pick_sequence=restrict_sequence_name,
            eval_batch_index=eval_batch_index,
            **common_kwargs,
        )

        if len(restrict_sequence_name) > 0:
            eval_batch_index = [
                b for b in eval_batch_index if b[0][0] in restrict_sequence_name
            ]
        test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
            eval_batch_index
        )

        dataset_map = DatasetMap(
            train=train_dataset, val=val_dataset, test=test_dataset
        )
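
The sequence-name filter that previously ran after the test dataset was built now runs before construction, so the filtered index can be passed straight to __init__. A standalone sketch of what the comprehension does, with made-up batch contents:

# b[0][0] is the sequence_name of the first (sequence_name, frame_number, ...) entry
# in batch b, so a batch is kept only if its first frame belongs to a requested sequence.
eval_batch_index = [
    [("skateboard_001", 0), ("skateboard_001", 9)],
    [("skateboard_002", 4)],
]
restrict_sequence_name = ["skateboard_001"]

if len(restrict_sequence_name) > 0:
    eval_batch_index = [
        b for b in eval_batch_index if b[0][0] in restrict_sequence_name
    ]

print(eval_batch_index)  # [[('skateboard_001', 0), ('skateboard_001', 9)]]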


@@ -11,6 +11,7 @@ import os
import warnings
from typing import Dict, List, Optional, Type

from omegaconf import DictConfig, open_dict

from pytorch3d.implicitron.dataset.dataset_map_provider import (
    DatasetMap,
    DatasetMapProviderBase,
@@ -269,6 +270,15 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
            train=train_dataset, val=val_dataset, test=test_dataset
        )

    @classmethod
    def dataset_tweak_args(cls, type, args: DictConfig) -> None:
        """
        Called by get_default_args(JsonIndexDatasetMapProviderV2) to
        not expose certain fields of each dataset class.
        """
        with open_dict(args):
            del args["eval_batch_index"]

    def create_dataset(self):
        # The dataset object is created inside `self.get_dataset_map`
        pass
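
dataset_tweak_args uses omegaconf's open_dict because a struct-mode DictConfig refuses key deletion; the context manager temporarily lifts that restriction. A minimal toy example of the pattern (this is not the provider's actual config object):

from omegaconf import OmegaConf, open_dict

# Toy config standing in for the dataset args; only the keys shown here are assumed.
args = OmegaConf.create({"eval_batch_index": None, "seed": 0})
OmegaConf.set_struct(args, True)  # struct mode forbids adding or removing keys

with open_dict(args):  # temporarily allow structural changes
    del args["eval_batch_index"]

print(list(args.keys()))  # ['seed']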