mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	Made eval_batches be set inside the __init__
Summary: Made eval_batches be set in call to `__init__` not after the construction as they were before Reviewed By: bottler Differential Revision: D38275943 fbshipit-source-id: 32737401d1ddd16c284e1851b7a91f8b041c406f
This commit is contained in:
		
							parent
							
								
									80fc0ee0b6
								
							
						
					
					
						commit
						597e0259dc
					
				@ -112,6 +112,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
        eval_batches: A list of batches that form the evaluation set;
 | 
			
		||||
            list of batch-sized lists of indices corresponding to __getitem__
 | 
			
		||||
            of this class, thus it can be used directly as a batch sampler.
 | 
			
		||||
        eval_batch_index:
 | 
			
		||||
            ( Optional[List[List[Union[Tuple[str, int, str], Tuple[str, int]]]] )
 | 
			
		||||
            A list of batches of frames described as (sequence_name, frame_idx)
 | 
			
		||||
            that can form the evaluation set, `eval_batches` will be set from this.
 | 
			
		||||
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    frame_annotations_type: ClassVar[
 | 
			
		||||
@ -147,6 +152,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
    seed: int = 0
 | 
			
		||||
    sort_frames: bool = False
 | 
			
		||||
    eval_batches: Any = None
 | 
			
		||||
    eval_batch_index: Any = None
 | 
			
		||||
    # frame_annots: List[FrameAnnotsEntry] = field(init=False)
 | 
			
		||||
    # seq_annots: Dict[str, types.SequenceAnnotation] = field(init=False)
 | 
			
		||||
 | 
			
		||||
@ -159,8 +165,20 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
 | 
			
		||||
            self._sort_frames()
 | 
			
		||||
        self._load_subset_lists()
 | 
			
		||||
        self._filter_db()  # also computes sequence indices
 | 
			
		||||
        self._extract_and_set_eval_batches()
 | 
			
		||||
        logger.info(str(self))
 | 
			
		||||
 | 
			
		||||
    def _extract_and_set_eval_batches(self):
 | 
			
		||||
        """
 | 
			
		||||
        Sets eval_batches based on input eval_batch_index.
 | 
			
		||||
        """
 | 
			
		||||
        if self.eval_batch_index is not None:
 | 
			
		||||
            if self.eval_batches is not None:
 | 
			
		||||
                raise ValueError(
 | 
			
		||||
                    "Cannot define both eval_batch_index and eval_batches."
 | 
			
		||||
                )
 | 
			
		||||
            self.eval_batches = self.seq_frame_index_to_dataset_index()
 | 
			
		||||
 | 
			
		||||
    def is_filtered(self):
 | 
			
		||||
        """
 | 
			
		||||
        Returns `True` in case the dataset has been filtered and thus some frame annotations
 | 
			
		||||
 | 
			
		||||
@ -57,6 +57,7 @@ _CO3D_DATASET_ROOT: str = os.getenv("CO3D_DATASET_ROOT", "")
 | 
			
		||||
_NEED_CONTROL: Tuple[str, ...] = (
 | 
			
		||||
    "dataset_root",
 | 
			
		||||
    "eval_batches",
 | 
			
		||||
    "eval_batch_index",
 | 
			
		||||
    "n_frames_per_sequence",
 | 
			
		||||
    "path_manager",
 | 
			
		||||
    "pick_sequence",
 | 
			
		||||
@ -212,6 +213,10 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
            ]
 | 
			
		||||
            # overwrite the restrict_sequence_name
 | 
			
		||||
            restrict_sequence_name = [eval_sequence_name]
 | 
			
		||||
        if len(restrict_sequence_name) > 0:
 | 
			
		||||
            eval_batch_index = [
 | 
			
		||||
                b for b in eval_batch_index if b[0][0] in restrict_sequence_name
 | 
			
		||||
            ]
 | 
			
		||||
 | 
			
		||||
        dataset_type: Type[JsonIndexDataset] = registry.get(
 | 
			
		||||
            JsonIndexDataset, self.dataset_class_type
 | 
			
		||||
@ -239,15 +244,9 @@ class JsonIndexDatasetMapProvider(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
                n_frames_per_sequence=-1,
 | 
			
		||||
                subsets=set_names_mapping["test"],
 | 
			
		||||
                pick_sequence=restrict_sequence_name,
 | 
			
		||||
                eval_batch_index=eval_batch_index,
 | 
			
		||||
                **common_kwargs,
 | 
			
		||||
            )
 | 
			
		||||
            if len(restrict_sequence_name) > 0:
 | 
			
		||||
                eval_batch_index = [
 | 
			
		||||
                    b for b in eval_batch_index if b[0][0] in restrict_sequence_name
 | 
			
		||||
                ]
 | 
			
		||||
            test_dataset.eval_batches = test_dataset.seq_frame_index_to_dataset_index(
 | 
			
		||||
                eval_batch_index
 | 
			
		||||
            )
 | 
			
		||||
        dataset_map = DatasetMap(
 | 
			
		||||
            train=train_dataset, val=val_dataset, test=test_dataset
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,7 @@ import os
 | 
			
		||||
import warnings
 | 
			
		||||
from typing import Dict, List, Optional, Type
 | 
			
		||||
 | 
			
		||||
from omegaconf import DictConfig, open_dict
 | 
			
		||||
from pytorch3d.implicitron.dataset.dataset_map_provider import (
 | 
			
		||||
    DatasetMap,
 | 
			
		||||
    DatasetMapProviderBase,
 | 
			
		||||
@ -269,6 +270,15 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 | 
			
		||||
            train=train_dataset, val=val_dataset, test=test_dataset
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def dataset_tweak_args(cls, type, args: DictConfig) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        Called by get_default_args(JsonIndexDatasetMapProviderV2) to
 | 
			
		||||
        not expose certain fields of each dataset class.
 | 
			
		||||
        """
 | 
			
		||||
        with open_dict(args):
 | 
			
		||||
            del args["eval_batch_index"]
 | 
			
		||||
 | 
			
		||||
    def create_dataset(self):
 | 
			
		||||
        # The dataset object is created inside `self.get_dataset_map`
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user