mirror of https://github.com/facebookresearch/pytorch3d.git
synced 2025-11-04 18:02:14 +08:00
	CO3Dv2 multi-category extension
Summary: Allows loading of multiple categories, provided as a comma-separated list of category names.

Reviewed By: bottler, shapovalov

Differential Revision: D40803297

fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8
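Editor's note (not part of the commit): a minimal usage sketch of the new comma-separated `category` field; the `dataset_root` and `subset_name` values here are placeholders, the import path is as in the repository.

from pytorch3d.implicitron.dataset.json_index_dataset_map_provider_v2 import (
    JsonIndexDatasetMapProviderV2,
)

# "apple,car,orange" requests three categories; the provider loads each one
# and joins the per-category train/val/test datasets into a single DatasetMap.
provider = JsonIndexDatasetMapProviderV2(
    category="apple,car,orange",
    subset_name="fewview_test",       # placeholder subset name
    dataset_root="/path/to/CO3Dv2",   # placeholder dataset root
)
dataset_map = provider.get_dataset_map()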
This commit is contained in:
parent c54e048666
commit e4a3298149
@@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args:
     test_on_train: false
     only_test_set: false
     load_eval_batches: true
+    num_load_workers: 4
     n_known_frames_for_test: 0
     dataset_class_type: JsonIndexDataset
     path_manager_factory_class_type: PathManagerFactory
@@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields
 from typing import (
     Any,
     ClassVar,
     Dict,
+    Iterable,
     Iterator,
     List,
@@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         """
         raise ValueError("This dataset does not contain videos.")
 
+    def join(self, other_datasets: Iterable["DatasetBase"]) -> None:
+        """
+        Joins the current dataset with a list of other datasets of the same type.
+        """
+        raise NotImplementedError()
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return None
 
@@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]):
         # pyre-ignore[16]
         return self._seq_to_idx.keys()
 
+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        """
+        Returns a dict mapping from each dataset category to a list of its
+        sequence names.
+
+        Returns:
+            category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]}
+        """
+        c2seq = defaultdict(list)
+        for sequence_name in self.sequence_names():
+            first_frame_idx = next(self.sequence_indices_in_order(sequence_name))
+            # crashes without overriding __getitem__
+            sequence_category = self[first_frame_idx].sequence_category
+            c2seq[sequence_category].append(sequence_name)
+        return dict(c2seq)
+
     def sequence_frames_in_order(
         self, seq_name: str
     ) -> Iterator[Tuple[float, int, int]]:
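Editor's note (not part of the diff): the grouping performed by the base implementation above, reduced to a self-contained toy; all names are made up.

from collections import defaultdict

# Toy stand-in: each sequence's first frame carries its category,
# mirroring what self[first_frame_idx].sequence_category provides.
first_frame_category = {
    "apple_seq_0": "apple",
    "apple_seq_1": "apple",
    "car_seq_0": "car",
}

c2seq = defaultdict(list)
for sequence_name, category in first_frame_category.items():
    c2seq[category].append(sequence_name)

print(dict(c2seq))
# {'apple': ['apple_seq_0', 'apple_seq_1'], 'car': ['car_seq_0']}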
@@ -7,7 +7,7 @@
 import logging
 import os
 from dataclasses import dataclass
-from typing import Iterator, Optional
+from typing import Iterable, Iterator, Optional
 
 from iopath.common.file_io import PathManager
 from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
@@ -51,6 +51,34 @@ class DatasetMap:
         if self.test is not None:
             yield self.test
 
+    def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None:
+        """
+        Joins the current DatasetMap with other dataset maps from the input list.
+
+        For each subset of each dataset map (train/val/test), the function
+        omits joining the subsets that are None.
+
+        Note the train/val/test datasets of the current dataset map will be
+        modified in-place.
+
+        Args:
+            other_dataset_maps: The list of dataset maps to be joined into the
+                current dataset map.
+        """
+        for set_ in ["train", "val", "test"]:
+            dataset_list = [
+                getattr(self, set_),
+                *[getattr(dmap, set_) for dmap in other_dataset_maps],
+            ]
+            dataset_list = [d for d in dataset_list if d is not None]
+            if len(dataset_list) == 0:
+                setattr(self, set_, None)
+                continue
+            d0 = dataset_list[0]
+            if len(dataset_list) > 1:
+                d0.join(dataset_list[1:])
+            setattr(self, set_, d0)
+
 
 class DatasetMapProviderBase(ReplaceableBase):
     """
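Editor's note (not part of the diff): a self-contained toy mirroring the subset-wise merge in DatasetMap.join above, with Python lists standing in for datasets and list.extend standing in for dataset join; all values are made up.

maps = [
    {"train": ["a1"], "val": ["a2"], "test": None},
    {"train": ["b1"], "val": None, "test": None},
]
merged = {}
for set_ in ["train", "val", "test"]:
    # subsets that are None are skipped, exactly as in the method above
    datasets = [m[set_] for m in maps if m[set_] is not None]
    if not datasets:
        merged[set_] = None
        continue
    d0 = datasets[0]
    for d in datasets[1:]:
        d0.extend(d)  # stands in for d0.join(...)
    merged[set_] = d0

print(merged)  # {'train': ['a1', 'b1'], 'val': ['a2'], 'test': None}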
@@ -19,6 +19,8 @@ from pathlib import Path
 from typing import (
     Any,
     ClassVar,
+    Dict,
+    Iterable,
     List,
     Optional,
     Sequence,
@@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 self.eval_batch_index
             )
 
-    def is_filtered(self):
+    def join(self, other_datasets: Iterable[DatasetBase]) -> None:
+        """
+        Join the dataset with other JsonIndexDataset objects.
+
+        Args:
+            other_datasets: A list of JsonIndexDataset objects to be joined
+                into the current dataset.
+        """
+        if not all(isinstance(d, JsonIndexDataset) for d in other_datasets):
+            raise ValueError("This function can only join a list of JsonIndexDataset")
+        # pyre-ignore[16]
+        self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots])
+        # pyre-ignore[16]
+        self.seq_annots.update(
+            # https://gist.github.com/treyhunner/f35292e676efa0be1728
+            functools.reduce(
+                lambda a, b: {**a, **b},
+                [d.seq_annots for d in other_datasets],  # pyre-ignore[16]
+            )
+        )
+        all_eval_batches = [
+            self.eval_batches,
+            # pyre-ignore
+            *[d.eval_batches for d in other_datasets],
+        ]
+        if not (
+            all(ba is None for ba in all_eval_batches)
+            or all(ba is not None for ba in all_eval_batches)
+        ):
+            raise ValueError(
+                "When joining datasets, either all joined datasets have to have their"
+                " eval_batches defined, or all should have their eval batches undefined."
+            )
+        if self.eval_batches is not None:
+            self.eval_batches = sum(all_eval_batches, [])
+        self._invalidate_indexes(filter_seq_annots=True)
+
+    def is_filtered(self) -> bool:
         """
         Returns `True` in case the dataset has been filtered and thus some frame annotations
         stored on the disk might be missing in the dataset object.
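Editor's note (not part of the diff): the dict-merge idiom used in join above (from the linked gist), shown standalone; the annotation values are placeholders.

import functools

# seq_annots dicts from two hypothetical datasets; later dicts win on
# duplicate keys, which is the behavior the reduce above relies on.
seq_annots_a = {"apple_seq_0": "annot_a"}
seq_annots_b = {"car_seq_0": "annot_b"}

merged = functools.reduce(
    lambda a, b: {**a, **b},
    [seq_annots_a, seq_annots_b],
)
print(merged)  # {'apple_seq_0': 'annot_a', 'car_seq_0': 'annot_b'}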
@@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
         allow_missing_indices: bool = False,
         remove_missing_indices: bool = False,
+        suppress_missing_index_warning: bool = True,
     ) -> List[List[Union[Optional[int], int]]]:
         """
         Obtain indices into the dataset object given a list of frame ids.
@@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 If `False`, returns `None` in place of `seq_frame_index` entries that
                 are not present in the dataset.
                 If `True` removes missing indices from the returned indices.
+            suppress_missing_index_warning:
+                Active if `allow_missing_indices==True`. Suppresses a warning message
+                in case an entry from `seq_frame_index` is missing in the dataset
+                (expected in certain cases - e.g. when setting
+                `self.remove_empty_masks=True`).
 
         Returns:
             dataset_idx: Indices of dataset entries corresponding to `seq_frame_index`.
@@ -254,7 +299,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
                 )
                 if not allow_missing_indices:
                     raise IndexError(msg)
-                warnings.warn(msg)
+                if not suppress_missing_index_warning:
+                    warnings.warn(msg)
                 return idx
             if path is not None:
                 # Check that the loaded frame path is consistent
@@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
         frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]],
         allow_missing_indices: bool = True,
     ) -> "JsonIndexDataset":
+        """
+        Generate a dataset subset given the list of frames specified in `frame_index`.
+
+        Args:
+            frame_index: The list of frame identifiers (as stored in the metadata)
+                specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally,
+                image paths relative to the dataset_root can be specified as well:
+                `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`;
+                in the latter case, if image_path does not match the stored path, an error
+                is raised.
+            allow_missing_indices: If `False`, throws an IndexError upon reaching the first
+                entry from `frame_index` which is missing in the dataset.
+                Otherwise, generates a subset consisting of frame entries that actually
+                exist in the dataset.
+        """
         # Get the indices into the frame annots.
         dataset_indices = self.seq_frame_index_to_dataset_index(
             [frame_index],
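Editor's note (not part of the diff): the two frame_index entry formats the docstring above describes; all values are invented for illustration.

# Both accepted shapes of frame_index entries:
frame_index = [
    ("teddybear_001", 0),
    # optional third element: image path relative to dataset_root; if it
    # disagrees with the stored path, an error is raised
    ("teddybear_001", 5, "teddybear/teddybear_001/images/frame000005.jpg"),
]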
@@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase):
             )
         return out
 
+    def category_to_sequence_names(self) -> Dict[str, List[str]]:
+        c2seq = defaultdict(list)
+        # pyre-ignore
+        for sequence_name, sa in self.seq_annots.items():
+            c2seq[sa.category].append(sequence_name)
+        return dict(c2seq)
+
     def get_eval_batches(self) -> Optional[List[List[int]]]:
         return self.eval_batches
 
@@ -8,6 +8,7 @@
 import copy
 import json
 import logging
+import multiprocessing
 import os
 import warnings
 from collections import defaultdict
@@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import (
 )
 
 from pytorch3d.renderer.cameras import CamerasBase
+from tqdm import tqdm
 
 
 _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "")
@@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     (test frames can repeat across batches).
 
     Args:
-        category: The object category of the dataset.
+        category: Dataset categories to load, expressed as a string of comma-separated
+            category names (e.g. `"apple,car,orange"`).
         subset_name: The name of the dataset subset. For CO3Dv2, these include
             e.g. "manyview_dev_0", "fewview_test", ...
         dataset_root: The root folder of the dataset.
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
     test_on_train: bool = False
     only_test_set: bool = False
     load_eval_batches: bool = True
+    num_load_workers: int = 4
 
     n_known_frames_for_test: int = 0
 
@@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
         if self.only_test_set and self.test_on_train:
             raise ValueError("Cannot have only_test_set and test_on_train")
 
-        frame_file = os.path.join(
-            self.dataset_root, self.category, "frame_annotations.jgz"
-        )
+        if "," in self.category:
+            # a comma-separated list of categories to load
+            categories = [c.strip() for c in self.category.split(",")]
+            logger.info(f"Loading a list of categories: {str(categories)}.")
+            with multiprocessing.Pool(
+                processes=min(self.num_load_workers, len(categories))
+            ) as pool:
+                category_dataset_maps = list(
+                    tqdm(
+                        pool.imap(self._load_category, categories),
+                        total=len(categories),
+                    )
+                )
+            dataset_map = category_dataset_maps[0]
+            dataset_map.join(category_dataset_maps[1:])
+
+        else:
+            # one category to load
+            dataset_map = self._load_category(self.category)
+
+        self.dataset_map = dataset_map
+
+    def _load_category(self, category: str) -> DatasetMap:
+
+        frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz")
         sequence_file = os.path.join(
-            self.dataset_root, self.category, "sequence_annotations.jgz"
+            self.dataset_root, category, "sequence_annotations.jgz"
         )
 
         path_manager = self.path_manager_factory.get()
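Editor's note (not part of the diff): the parallel per-category loading pattern above, reduced to a self-contained sketch; load_one is a stand-in for _load_category, and the __main__ guard is needed on spawn-based platforms.

import multiprocessing

from tqdm import tqdm


def load_one(category: str) -> str:
    # stand-in for JsonIndexDatasetMapProviderV2._load_category
    return f"DatasetMap({category})"


if __name__ == "__main__":
    categories = ["apple", "car", "orange"]
    num_load_workers = 4
    with multiprocessing.Pool(
        processes=min(num_load_workers, len(categories))
    ) as pool:
        maps = list(
            tqdm(pool.imap(load_one, categories), total=len(categories))
        )
    print(maps)  # per-category results, in input order (imap preserves order)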
@@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 
         dataset = dataset_type(**common_dataset_kwargs)
 
-        available_subset_names = self._get_available_subset_names()
+        available_subset_names = self._get_available_subset_names(category)
         logger.debug(f"Available subset names: {str(available_subset_names)}.")
         if self.subset_name not in available_subset_names:
             raise ValueError(
@@ -242,20 +268,20 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
 
         # load the list of train/val/test frames
         subset_mapping = self._load_annotation_json(
-            os.path.join(
-                self.category, "set_lists", f"set_lists_{self.subset_name}.json"
-            )
+            os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json")
         )
 
         # load the evaluation batches
         if self.load_eval_batches:
             eval_batch_index = self._load_annotation_json(
                 os.path.join(
-                    self.category,
+                    category,
                     "eval_batches",
                     f"eval_batches_{self.subset_name}.json",
                 )
             )
         else:
             eval_batch_index = None
 
         train_dataset = None
         if not self.only_test_set:
@@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
                     )
                 logger.info(f"# eval batches: {len(test_dataset.eval_batches)}")
 
-        self.dataset_map = DatasetMap(
-            train=train_dataset, val=val_dataset, test=test_dataset
-        )
+        return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset)
 
     @classmethod
     def dataset_tweak_args(cls, type, args: DictConfig) -> None:
@@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase):  # pyre-ignore [13]
             data = json.load(f)
         return data
 
-    def _get_available_subset_names(self):
+    def _get_available_subset_names(self, category: str):
         return get_available_subset_names(
             self.dataset_root,
-            self.category,
+            category,
             path_manager=self.path_manager_factory.get(),
         )
 
@@ -6,8 +6,9 @@
 
 
 import warnings
+from collections import Counter
 from dataclasses import dataclass, field
-from typing import Iterable, Iterator, List, Sequence, Tuple
+from typing import Dict, Iterable, Iterator, List, Sequence, Tuple
 
 import numpy as np
 from torch.utils.data.sampler import Sampler
@@ -42,8 +43,17 @@ class SceneBatchSampler(Sampler[List[int]]):
     # same but for timestamps if they are available
     consecutive_frames_max_gap_seconds: float = 0.1
 
+    # if True, the sampler first reads from the dataset the mapping between
+    # sequence names and their categories.
+    # During batch sampling, the sampler ensures a uniform distribution over
+    # the categories of the sampled sequences.
+    category_aware: bool = True
+
     seq_names: List[str] = field(init=False)
 
+    category_to_sequence_names: Dict[str, List[str]] = field(init=False)
+    categories: List[str] = field(init=False)
+
     def __post_init__(self) -> None:
         if self.batch_size <= 0:
             raise ValueError(
@@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]):
 
         self.seq_names = list(self.dataset.sequence_names())
 
+        if self.category_aware:
+            self.category_to_sequence_names = self.dataset.category_to_sequence_names()
+            self.categories = list(self.category_to_sequence_names.keys())
+
     def __len__(self) -> int:
         return self.num_batches
 
@@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]):
     def _sample_batch(self, batch_idx) -> List[int]:
         n_per_seq = np.random.choice(self.images_per_seq_options)
         n_seqs = -(-self.batch_size // n_per_seq)  # round up
-        chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False)
+
+        if self.category_aware:
+            # first sample categories at random, these can be repeated in the batch
+            chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True)
+            # then randomly sample a set of unique sequences within each category
+            chosen_seq = []
+            for cat, n_per_category in Counter(chosen_cat).items():
+                category_chosen_seq = _capped_random_choice(
+                    self.category_to_sequence_names[cat],
+                    n_per_category,
+                    replace=False,
+                )
+                chosen_seq.extend([str(s) for s in category_chosen_seq])
+        else:
+            chosen_seq = _capped_random_choice(
+                self.seq_names,
+                n_seqs,
+                replace=False,
+            )
 
         if self.sample_consecutive_frames:
             frame_idx = []
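Editor's note (not part of the diff): a self-contained sketch of the two-stage, category-uniform sampling implemented above; _capped_random_choice is approximated with a capped numpy choice, and all names are made up.

from collections import Counter

import numpy as np

# Made-up category -> sequence-name mapping.
cat_to_seqs = {"apple": ["a0", "a1"], "car": ["c0", "c1", "c2"]}
categories = list(cat_to_seqs)
n_seqs = 4

# Stage 1: draw categories uniformly, with repetition allowed.
chosen_cat = np.random.choice(categories, n_seqs, replace=True)

# Stage 2: within each drawn category, draw unique sequences; the count is
# capped by the number of available sequences, like _capped_random_choice.
chosen_seq = []
for cat, n_per_category in Counter(chosen_cat).items():
    n = min(n_per_category, len(cat_to_seqs[cat]))
    chosen_seq.extend(np.random.choice(cat_to_seqs[cat], n, replace=False))

print([str(s) for s in chosen_seq])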
@@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
   test_on_train: false
   only_test_set: false
   load_eval_batches: true
+  num_load_workers: 4
   n_known_frames_for_test: 0
   dataset_class_type: JsonIndexDataset
   path_manager_factory_class_type: PathManagerFactory
@@ -11,17 +11,20 @@ from dataclasses import dataclass
 from itertools import product
 
 import numpy as np
 
+import torch
 from pytorch3d.implicitron.dataset.data_loader_map_provider import (
     DoublePoolBatchSampler,
 )
 
-from pytorch3d.implicitron.dataset.dataset_base import DatasetBase
+from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
 from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler
 
 
 @dataclass
 class MockFrameAnnotation:
     frame_number: int
+    sequence_name: str = "sequence"
     frame_timestamp: float = 0.0
@@ -41,6 +44,9 @@ class MockDataset(DatasetBase):
         self.frame_annots = [
             {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq
         ]
+        for seq_name, idx in self._seq_to_idx.items():
+            for i in idx:
+                self.frame_annots[i]["frame_annotation"].sequence_name = seq_name
 
     def get_frame_numbers_and_timestamps(self, idxs):
         out = []
@@ -51,6 +57,16 @@ class MockDataset(DatasetBase):
             )
         return out
 
+    def __getitem__(self, index: int):
+        fa = self.frame_annots[index]["frame_annotation"]
+        fd = FrameData(
+            sequence_name=fa.sequence_name,
+            sequence_category="default_category",
+            frame_number=torch.LongTensor([fa.frame_number]),
+            frame_timestamp=torch.LongTensor([fa.frame_timestamp]),
+        )
+        return fd
+
 
 class TestSceneBatchSampler(unittest.TestCase):
     def setUp(self):
@@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
         categories = ["A", "B"]
         subset_name = "test"
         eval_batch_size = 5
+        n_frames = 8 * 3
+        n_sequences = 5
+        n_eval_batches = 10
         with tempfile.TemporaryDirectory() as tmpd:
             _make_random_json_dataset_map_provider_v2_data(
                 tmpd,
                 categories,
                 eval_batch_size=eval_batch_size,
+                n_frames=n_frames,
+                n_sequences=n_sequences,
+                n_eval_batches=n_eval_batches,
             )
             for n_known_frames_for_test in [0, 2]:
-                for category in categories:
-                    dataset_provider = JsonIndexDatasetMapProviderV2(
+                dataset_providers = {
+                    category: JsonIndexDatasetMapProviderV2(
                         category=category,
                         subset_name="test",
                         dataset_root=tmpd,
                         n_known_frames_for_test=n_known_frames_for_test,
                     )
+                    for category in [*categories, ",".join(sorted(categories))]
+                }
+                for category, dataset_provider in dataset_providers.items():
                     dataset_map = dataset_provider.get_dataset_map()
                     for set_ in ["train", "val", "test"]:
                         dataset = getattr(dataset_map, set_)
 
+                        cat2seq = dataset.category_to_sequence_names()
+                        self.assertEqual(",".join(sorted(cat2seq.keys())), category)
+
+                        if not (n_known_frames_for_test != 0 and set_ == "test"):
+                            # check the lengths only in case we do not have the
+                            # n_known_frames_for_test set
+                            expected_dataset_len = n_frames * n_sequences // 3
+                            if "," in category:
+                                # multicategory json index dataset, sum the lengths of
+                                # category-specific ones
+                                expected_dataset_len = sum(
+                                    len(
+                                        getattr(
+                                            dataset_providers[c].get_dataset_map(), set_
+                                        )
+                                    )
+                                    for c in categories
+                                )
+                                self.assertEqual(
+                                    sum(len(s) for s in cat2seq.values()),
+                                    n_sequences * len(categories),
+                                )
+                                self.assertEqual(len(cat2seq), len(categories))
+                            else:
+                                self.assertEqual(
+                                    len(cat2seq[category]),
+                                    n_sequences,
+                                )
+                                self.assertEqual(len(cat2seq), 1)
+                            self.assertEqual(len(dataset), expected_dataset_len)
+
+                        if set_ == "test":
+                            # check the number of eval batches
+                            expected_n_eval_batches = n_eval_batches
+                            if "," in category:
+                                expected_n_eval_batches *= len(categories)
+                            self.assertTrue(
+                                len(dataset.get_eval_batches())
+                                == expected_n_eval_batches
+                            )
                         if set_ in ["train", "val"]:
                             dataloader = torch.utils.data.DataLoader(
                                 getattr(dataset_map, set_),
@@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase):
                         dataset_provider.get_category_to_subset_name_list()
                     )
                     category_to_subset_list_ = {c: [subset_name] for c in categories}
+
                     self.assertTrue(category_to_subset_list == category_to_subset_list_)
 
 
@@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data(
     categories: List[str],
     n_frames: int = 8,
     n_sequences: int = 5,
+    n_eval_batches: int = 10,
     H: int = 50,
     W: int = 30,
     subset_name: str = "test",
@@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data(
         sequence_annotations = []
         frame_index = []
         for seq_i in range(n_sequences):
-            seq_name = str(seq_i)
+            seq_name = category + str(seq_i)
             for i in range(n_frames):
                 # generate and store image
                 imdir = os.path.join(root, category, seq_name, "images")
@@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data(
             json.dump(set_list, f)
 
         eval_batches = [
-            random.sample(test_frame_index, eval_batch_size) for _ in range(10)
+            random.sample(test_frame_index, eval_batch_size)
+            for _ in range(n_eval_batches)
         ]
 
         eval_b_dir = os.path.join(root, category, "eval_batches")