mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Made eval_batches be set inside the __init__
Summary: Made eval_batches be set in call to `__init__` not after the construction as they were before Reviewed By: bottler Differential Revision: D38275943 fbshipit-source-id: 32737401d1ddd16c284e1851b7a91f8b041c406f
This commit is contained in:
parent
80fc0ee0b6
commit
597e0259dc
@ -112,6 +112,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
eval_batches: A list of batches that form the evaluation set;
|
||||
list of batch-sized lists of indices corresponding to __getitem__
|
||||
of this class, thus it can be used directly as a batch sampler.
|
||||
eval_batch_index:
|
||||
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
||||
A list of batches of frames described as (sequence_name, frame_idx)
|
||||
that can form the evaluation set, `eval_batches` will be set from this.
|
||||
|
||||
"""
|
||||
|
||||
frame_annotations_type: ClassVar[
|
||||
@ -147,6 +152,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
seed: int = 0
|
||||
sort_frames: bool = False
|
||||
eval_batches: Any = None
|
||||
eval_batch_index: Any = None
|
||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||
|
||||
@ -159,8 +165,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
||||
self._sort_frames()
|
||||
self._load_subset_lists()
|
||||
self._filter_db() # also computes sequence indices
|
||||
self._extract_and_set_eval_batches()
|
||||
logger.info(str(self))
|
||||
|
||||
def _extract_and_set_eval_batches(self):
|
||||
"""
|
||||
Sets eval_batches based on input eval_batch_index.
|
||||
"""
|
||||
if self.eval_batch_index is not None:
|
||||
if self.eval_batches is not None:
|
||||
raise ValueError(
|
||||
"Cannot define both eval_batch_index and eval_batches."
|
||||
)
|
||||
self.eval_batches = self.seq_frame_index_to_dataset_index()
|
||||
|
||||
def is_filtered(self):
|
||||
"""
|
||||
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||
|
@ -57,6 +57,7 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
||||
_NEED_CONTROL: Tuple[str, ...] = (
|
||||
"dataset_root",
|
||||
"eval_batches",
|
||||
"eval_batch_index",
|
||||
"n_frames_per_sequence",
|
||||
"path_manager",
|
||||
"pick_sequence",
|
||||
@ -212,6 +213,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
]
|
||||
# overwrite the restrict_sequence_name
|
||||
restrict_sequence_name = [eval_sequence_name]
|
||||
if len(restrict_sequence_name) > 0:
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||
]
|
||||
|
||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||
JsonIndexDataset, self.dataset_class_type
|
||||
@ -239,15 +244,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
n_frames_per_sequence=-1,
|
||||
subsets=set_names_mapping["test"],
|
||||
pick_sequence=restrict_sequence_name,
|
||||
eval_batch_index=eval_batch_index,
|
||||
**common_kwargs,
|
||||
)
|
||||
if len(restrict_sequence_name) > 0:
|
||||
eval_batch_index = [
|
||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||
]
|
||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
||||
eval_batch_index
|
||||
)
|
||||
dataset_map = DatasetMap(
|
||||
train=train_dataset, val=val_dataset, test=test_dataset
|
||||
)
|
||||
|
@ -11,6 +11,7 @@ import os
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from omegaconf import DictConfig, open_dict
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||
DatasetMap,
|
||||
DatasetMapProviderBase,
|
||||
@ -269,6 +270,15 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
||||
train=train_dataset, val=val_dataset, test=test_dataset
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||
"""
|
||||
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
||||
not expose certain fields of each dataset class.
|
||||
"""
|
||||
with open_dict(args):
|
||||
del args["eval_batch_index"]
|
||||
|
||||
def create_dataset(self):
|
||||
# The dataset object is created inside `self.get_dataset_map`
|
||||
pass
|
||||
|
Loading…
x
Reference in New Issue
Block a user