mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 11:52:50 +08:00
Made eval_batches be set inside the __init__
Summary: Made eval_batches be set in call to `__init__` not after the construction as they were before Reviewed By: bottler Differential Revision: D38275943 fbshipit-source-id: 32737401d1ddd16c284e1851b7a91f8b041c406f
This commit is contained in:
parent
80fc0ee0b6
commit
597e0259dc
@ -112,6 +112,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
eval_batches: A list of batches that form the evaluation set;
|
eval_batches: A list of batches that form the evaluation set;
|
||||||
list of batch-sized lists of indices corresponding to __getitem__
|
list of batch-sized lists of indices corresponding to __getitem__
|
||||||
of this class, thus it can be used directly as a batch sampler.
|
of this class, thus it can be used directly as a batch sampler.
|
||||||
|
eval_batch_index:
|
||||||
|
( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
|
||||||
|
A list of batches of frames described as (sequence_name, frame_idx)
|
||||||
|
that can form the evaluation set, `eval_batches` will be set from this.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frame_annotations_type: ClassVar[
|
frame_annotations_type: ClassVar[
|
||||||
@ -147,6 +152,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
seed: int = 0
|
seed: int = 0
|
||||||
sort_frames: bool = False
|
sort_frames: bool = False
|
||||||
eval_batches: Any = None
|
eval_batches: Any = None
|
||||||
|
eval_batch_index: Any = None
|
||||||
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
# frame_annots: List[FrameAnnotsEntry] = field(init=False)
|
||||||
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
# seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
|
||||||
|
|
||||||
@ -159,8 +165,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
|
|||||||
self._sort_frames()
|
self._sort_frames()
|
||||||
self._load_subset_lists()
|
self._load_subset_lists()
|
||||||
self._filter_db() # also computes sequence indices
|
self._filter_db() # also computes sequence indices
|
||||||
|
self._extract_and_set_eval_batches()
|
||||||
logger.info(str(self))
|
logger.info(str(self))
|
||||||
|
|
||||||
|
def _extract_and_set_eval_batches(self):
|
||||||
|
"""
|
||||||
|
Sets eval_batches based on input eval_batch_index.
|
||||||
|
"""
|
||||||
|
if self.eval_batch_index is not None:
|
||||||
|
if self.eval_batches is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Cannot define both eval_batch_index and eval_batches."
|
||||||
|
)
|
||||||
|
self.eval_batches = self.seq_frame_index_to_dataset_index()
|
||||||
|
|
||||||
def is_filtered(self):
|
def is_filtered(self):
|
||||||
"""
|
"""
|
||||||
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
Returns `True` in case the dataset has been filtered and thus some frame annotations
|
||||||
|
@ -57,6 +57,7 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
|
|||||||
_NEED_CONTROL: Tuple[str, ...] = (
|
_NEED_CONTROL: Tuple[str, ...] = (
|
||||||
"dataset_root",
|
"dataset_root",
|
||||||
"eval_batches",
|
"eval_batches",
|
||||||
|
"eval_batch_index",
|
||||||
"n_frames_per_sequence",
|
"n_frames_per_sequence",
|
||||||
"path_manager",
|
"path_manager",
|
||||||
"pick_sequence",
|
"pick_sequence",
|
||||||
@ -212,6 +213,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
]
|
]
|
||||||
# overwrite the restrict_sequence_name
|
# overwrite the restrict_sequence_name
|
||||||
restrict_sequence_name = [eval_sequence_name]
|
restrict_sequence_name = [eval_sequence_name]
|
||||||
|
if len(restrict_sequence_name) > 0:
|
||||||
|
eval_batch_index = [
|
||||||
|
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
||||||
|
]
|
||||||
|
|
||||||
dataset_type: Type[JsonIndexDataset] = registry.get(
|
dataset_type: Type[JsonIndexDataset] = registry.get(
|
||||||
JsonIndexDataset, self.dataset_class_type
|
JsonIndexDataset, self.dataset_class_type
|
||||||
@ -239,15 +244,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
n_frames_per_sequence=-1,
|
n_frames_per_sequence=-1,
|
||||||
subsets=set_names_mapping["test"],
|
subsets=set_names_mapping["test"],
|
||||||
pick_sequence=restrict_sequence_name,
|
pick_sequence=restrict_sequence_name,
|
||||||
|
eval_batch_index=eval_batch_index,
|
||||||
**common_kwargs,
|
**common_kwargs,
|
||||||
)
|
)
|
||||||
if len(restrict_sequence_name) > 0:
|
|
||||||
eval_batch_index = [
|
|
||||||
b for b in eval_batch_index if b[0][0] in restrict_sequence_name
|
|
||||||
]
|
|
||||||
test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
|
|
||||||
eval_batch_index
|
|
||||||
)
|
|
||||||
dataset_map = DatasetMap(
|
dataset_map = DatasetMap(
|
||||||
train=train_dataset, val=val_dataset, test=test_dataset
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
)
|
)
|
||||||
|
@ -11,6 +11,7 @@ import os
|
|||||||
import warnings
|
import warnings
|
||||||
from typing import Dict, List, Optional, Type
|
from typing import Dict, List, Optional, Type
|
||||||
|
|
||||||
|
from omegaconf import DictConfig, open_dict
|
||||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
from pytorch3d.implicitron.dataset.dataset_map_provider import (
|
||||||
DatasetMap,
|
DatasetMap,
|
||||||
DatasetMapProviderBase,
|
DatasetMapProviderBase,
|
||||||
@ -269,6 +270,15 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
|
|||||||
train=train_dataset, val=val_dataset, test=test_dataset
|
train=train_dataset, val=val_dataset, test=test_dataset
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def dataset_tweak_args(cls, type, args: DictConfig) -> None:
|
||||||
|
"""
|
||||||
|
Called by get_default_args(JsonIndexDatasetMapProviderV2) to
|
||||||
|
not expose certain fields of each dataset class.
|
||||||
|
"""
|
||||||
|
with open_dict(args):
|
||||||
|
del args["eval_batch_index"]
|
||||||
|
|
||||||
def create_dataset(self):
|
def create_dataset(self):
|
||||||
# The dataset object is created inside `self.get_dataset_map`
|
# The dataset object is created inside `self.get_dataset_map`
|
||||||
pass
|
pass
|
||||||
|
Loading…
x
Reference in New Issue
Block a user