From e4a329814978934b2fda8fb92c2c14ebb42fa0b2 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Wed, 2 Nov 2022 13:55:25 -0700 Subject: [PATCH] CO3Dv2 multi-category extension Summary: Allows loading of multiple categories. Multiple categories are provided in a comma-separated list of category names. Reviewed By: bottler, shapovalov Differential Revision: D40803297 fbshipit-source-id: 863938be3aa6ffefe9e563aede4a2e9e66aeeaa8 --- .../implicitron_trainer/tests/experiment.yaml | 1 + pytorch3d/implicitron/dataset/dataset_base.py | 23 ++++++ .../dataset/dataset_map_provider.py | 30 +++++++- .../implicitron/dataset/json_index_dataset.py | 72 ++++++++++++++++++- .../json_index_dataset_map_provider_v2.py | 54 ++++++++++---- .../dataset/scene_batch_sampler.py | 36 +++++++++- tests/implicitron/data/data_source.yaml | 1 + tests/implicitron/test_batch_sampler.py | 18 ++++- .../test_json_index_dataset_provider_v2.py | 62 ++++++++++++++-- 9 files changed, 272 insertions(+), 25 deletions(-) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index db364ff3..613ca168 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -62,6 +62,7 @@ data_source_ImplicitronDataSource_args: test_on_train: false only_test_set: false load_eval_batches: true + num_load_workers: 4 n_known_frames_for_test: 0 dataset_class_type: JsonIndexDataset path_manager_factory_class_type: PathManagerFactory diff --git a/pytorch3d/implicitron/dataset/dataset_base.py b/pytorch3d/implicitron/dataset/dataset_base.py index f3d8d615..802d04e3 100644 --- a/pytorch3d/implicitron/dataset/dataset_base.py +++ b/pytorch3d/implicitron/dataset/dataset_base.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field, fields from typing import ( Any, ClassVar, + Dict, Iterable, Iterator, List, @@ -259,6 +260,12 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): """ 
raise ValueError("This dataset does not contain videos.") + def join(self, other_datasets: Iterable["DatasetBase"]) -> None: + """ + Joins the current dataset with a list of other datasets of the same type. + """ + raise NotImplementedError() + def get_eval_batches(self) -> Optional[List[List[int]]]: return None @@ -267,6 +274,22 @@ class DatasetBase(_GenericWorkaround, torch.utils.data.Dataset[FrameData]): # pyre-ignore[16] return self._seq_to_idx.keys() + def category_to_sequence_names(self) -> Dict[str, List[str]]: + """ + Returns a dict mapping from each dataset category to a list of its + sequence names. + + Returns: + category_to_sequence_names: Dict {category_i: [..., sequence_name_j, ...]} + """ + c2seq = defaultdict(list) + for sequence_name in self.sequence_names(): + first_frame_idx = next(self.sequence_indices_in_order(sequence_name)) + # crashes without overriding __getitem__ + sequence_category = self[first_frame_idx].sequence_category + c2seq[sequence_category].append(sequence_name) + return dict(c2seq) + def sequence_frames_in_order( self, seq_name: str ) -> Iterator[Tuple[float, int, int]]: diff --git a/pytorch3d/implicitron/dataset/dataset_map_provider.py b/pytorch3d/implicitron/dataset/dataset_map_provider.py index a1f62761..17569e52 100644 --- a/pytorch3d/implicitron/dataset/dataset_map_provider.py +++ b/pytorch3d/implicitron/dataset/dataset_map_provider.py @@ -7,7 +7,7 @@ import logging import os from dataclasses import dataclass -from typing import Iterator, Optional +from typing import Iterable, Iterator, Optional from iopath.common.file_io import PathManager from pytorch3d.implicitron.tools.config import registry, ReplaceableBase @@ -51,6 +51,34 @@ class DatasetMap: if self.test is not None: yield self.test + def join(self, other_dataset_maps: Iterable["DatasetMap"]) -> None: + """ + Joins the current DatasetMap with other dataset maps from the input list. 
+ + For each subset of each dataset map (train/val/test), the function + omits joining the subsets that are None. + + Note the train/val/test datasets of the current dataset map will be + modified in-place. + + Args: + other_dataset_maps: The list of dataset maps to be joined into the + current dataset map. + """ + for set_ in ["train", "val", "test"]: + dataset_list = [ + getattr(self, set_), + *[getattr(dmap, set_) for dmap in other_dataset_maps], + ] + dataset_list = [d for d in dataset_list if d is not None] + if len(dataset_list) == 0: + setattr(self, set_, None) + continue + d0 = dataset_list[0] + if len(dataset_list) > 1: + d0.join(dataset_list[1:]) + setattr(self, set_, d0) + class DatasetMapProviderBase(ReplaceableBase): """ diff --git a/pytorch3d/implicitron/dataset/json_index_dataset.py b/pytorch3d/implicitron/dataset/json_index_dataset.py index 6dfccccc..2fdab768 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset.py @@ -19,6 +19,8 @@ from pathlib import Path from typing import ( Any, ClassVar, + Dict, + Iterable, List, Optional, Sequence, @@ -188,7 +190,44 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): self.eval_batch_index ) - def is_filtered(self): + def join(self, other_datasets: Iterable[DatasetBase]) -> None: + """ + Join the dataset with other JsonIndexDataset objects. + + Args: + other_datasets: A list of JsonIndexDataset objects to be joined + into the current dataset. 
+ """ + if not all(isinstance(d, JsonIndexDataset) for d in other_datasets): + raise ValueError("This function can only join a list of JsonIndexDataset") + # pyre-ignore[16] + self.frame_annots.extend([fa for d in other_datasets for fa in d.frame_annots]) + # pyre-ignore[16] + self.seq_annots.update( + # https://gist.github.com/treyhunner/f35292e676efa0be1728 + functools.reduce( + lambda a, b: {**a, **b}, + [d.seq_annots for d in other_datasets], # pyre-ignore[16] + ) + ) + all_eval_batches = [ + self.eval_batches, + # pyre-ignore + *[d.eval_batches for d in other_datasets], + ] + if not ( + all(ba is None for ba in all_eval_batches) + or all(ba is not None for ba in all_eval_batches) + ): + raise ValueError( + "When joining datasets, either all joined datasets have to have their" + " eval_batches defined, or all should have their eval batches undefined." + ) + if self.eval_batches is not None: + self.eval_batches = sum(all_eval_batches, []) + self._invalidate_indexes(filter_seq_annots=True) + + def is_filtered(self) -> bool: """ Returns `True` in case the dataset has been filtered and thus some frame annotations stored on the disk might be missing in the dataset object. @@ -211,6 +250,7 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): seq_frame_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], allow_missing_indices: bool = False, remove_missing_indices: bool = False, + suppress_missing_index_warning: bool = True, ) -> List[List[Union[Optional[int], int]]]: """ Obtain indices into the dataset object given a list of frame ids. @@ -228,6 +268,11 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): If `False`, returns `None` in place of `seq_frame_index` entries that are not present in the dataset. If `True` removes missing indices from the returned indices. + suppress_missing_index_warning: + Active if `allow_missing_indices==True`. 
Suppresses a warning message + in case an entry from `seq_frame_index` is missing in the dataset + (expected in certain cases - e.g. when setting + `self.remove_empty_masks=True`). Returns: dataset_idx: Indices of dataset entries corresponding to`seq_frame_index`. @@ -254,7 +299,8 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): ) if not allow_missing_indices: raise IndexError(msg) - warnings.warn(msg) + if not suppress_missing_index_warning: + warnings.warn(msg) return idx if path is not None: # Check that the loaded frame path is consistent @@ -288,6 +334,21 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): frame_index: List[Union[Tuple[str, int], Tuple[str, int, str]]], allow_missing_indices: bool = True, ) -> "JsonIndexDataset": + """ + Generate a dataset subset given the list of frames specified in `frame_index`. + + Args: + frame_index: The list of frame identifiers (as stored in the metadata) + specified as `List[Tuple[sequence_name:str, frame_number:int]]`. Optionally, + image paths relative to the dataset_root can be specified as well: + `List[Tuple[sequence_name:str, frame_number:int, image_path:str]]`, + in the latter case, if image_path does not match the stored path, an error + is raised. + allow_missing_indices: If `False`, throws an IndexError upon reaching the first + entry from `frame_index` which is missing in the dataset. + Otherwise, generates a subset consisting of frame entries that actually + exist in the dataset. + """ # Get the indices into the frame annots. 
dataset_indices = self.seq_frame_index_to_dataset_index( [frame_index], @@ -838,6 +899,13 @@ class JsonIndexDataset(DatasetBase, ReplaceableBase): ) return out + def category_to_sequence_names(self) -> Dict[str, List[str]]: + c2seq = defaultdict(list) + # pyre-ignore + for sequence_name, sa in self.seq_annots.items(): + c2seq[sa.category].append(sequence_name) + return dict(c2seq) + def get_eval_batches(self) -> Optional[List[List[int]]]: return self.eval_batches diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py index 7f3d1a8a..d8790d35 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py @@ -8,6 +8,7 @@ import copy import json import logging +import multiprocessing import os import warnings from collections import defaultdict @@ -30,6 +31,7 @@ from pytorch3d.implicitron.tools.config import ( ) from pytorch3d.renderer.cameras import CamerasBase +from tqdm import tqdm _CO3DV2_DATASET_ROOT: str = os.getenv("CO3DV2_DATASET_ROOT", "") @@ -147,7 +149,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] (test frames can repeat across batches). Args: - category: The object category of the dataset. + category: Dataset categories to load expressed as a string of comma-separated + category names (e.g. `"apple,car,orange"`). subset_name: The name of the dataset subset. For CO3Dv2, these include e.g. "manyview_dev_0", "fewview_test", ... dataset_root: The root folder of the dataset. 
@@ -173,6 +176,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] test_on_train: bool = False only_test_set: bool = False load_eval_batches: bool = True + num_load_workers: int = 4 n_known_frames_for_test: int = 0 @@ -189,11 +193,33 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] if self.only_test_set and self.test_on_train: raise ValueError("Cannot have only_test_set and test_on_train") - frame_file = os.path.join( - self.dataset_root, self.category, "frame_annotations.jgz" - ) + if "," in self.category: + # a comma-separated list of categories to load + categories = [c.strip() for c in self.category.split(",")] + logger.info(f"Loading a list of categories: {str(categories)}.") + with multiprocessing.Pool( + processes=min(self.num_load_workers, len(categories)) + ) as pool: + category_dataset_maps = list( + tqdm( + pool.imap(self._load_category, categories), + total=len(categories), + ) + ) + dataset_map = category_dataset_maps[0] + dataset_map.join(category_dataset_maps[1:]) + + else: + # one category to load + dataset_map = self._load_category(self.category) + + self.dataset_map = dataset_map + + def _load_category(self, category: str) -> DatasetMap: + + frame_file = os.path.join(self.dataset_root, category, "frame_annotations.jgz") sequence_file = os.path.join( - self.dataset_root, self.category, "sequence_annotations.jgz" + self.dataset_root, category, "sequence_annotations.jgz" ) path_manager = self.path_manager_factory.get() @@ -232,7 +258,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] dataset = dataset_type(**common_dataset_kwargs) - available_subset_names = self._get_available_subset_names() + available_subset_names = self._get_available_subset_names(category) logger.debug(f"Available subset names: {str(available_subset_names)}.") if self.subset_name not in available_subset_names: raise ValueError( @@ -242,20 +268,20 @@ class 
JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] # load the list of train/val/test frames subset_mapping = self._load_annotation_json( - os.path.join( - self.category, "set_lists", f"set_lists_{self.subset_name}.json" - ) + os.path.join(category, "set_lists", f"set_lists_{self.subset_name}.json") ) # load the evaluation batches if self.load_eval_batches: eval_batch_index = self._load_annotation_json( os.path.join( - self.category, + category, "eval_batches", f"eval_batches_{self.subset_name}.json", ) ) + else: + eval_batch_index = None train_dataset = None if not self.only_test_set: @@ -313,9 +339,7 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] ) logger.info(f"# eval batches: {len(test_dataset.eval_batches)}") - self.dataset_map = DatasetMap( - train=train_dataset, val=val_dataset, test=test_dataset - ) + return DatasetMap(train=train_dataset, val=val_dataset, test=test_dataset) @classmethod def dataset_tweak_args(cls, type, args: DictConfig) -> None: @@ -381,10 +405,10 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] data = json.load(f) return data - def _get_available_subset_names(self): + def _get_available_subset_names(self, category: str): return get_available_subset_names( self.dataset_root, - self.category, + category, path_manager=self.path_manager_factory.get(), ) diff --git a/pytorch3d/implicitron/dataset/scene_batch_sampler.py b/pytorch3d/implicitron/dataset/scene_batch_sampler.py index 2012e706..f724fd07 100644 --- a/pytorch3d/implicitron/dataset/scene_batch_sampler.py +++ b/pytorch3d/implicitron/dataset/scene_batch_sampler.py @@ -6,8 +6,9 @@ import warnings +from collections import Counter from dataclasses import dataclass, field -from typing import Iterable, Iterator, List, Sequence, Tuple +from typing import Dict, Iterable, Iterator, List, Sequence, Tuple import numpy as np from torch.utils.data.sampler import Sampler @@ -42,8 +43,17 @@ class 
SceneBatchSampler(Sampler[List[int]]): # same but for timestamps if they are available consecutive_frames_max_gap_seconds: float = 0.1 + # if True, the sampler first reads from the dataset the mapping between + # sequence names and their categories. + # During batch sampling, the sampler ensures uniform distribution over the categories + # of the sampled sequences. + category_aware: bool = True + seq_names: List[str] = field(init=False) + category_to_sequence_names: Dict[str, List[str]] = field(init=False) + categories: List[str] = field(init=False) + def __post_init__(self) -> None: if self.batch_size <= 0: raise ValueError( @@ -56,6 +66,10 @@ class SceneBatchSampler(Sampler[List[int]]): self.seq_names = list(self.dataset.sequence_names()) + if self.category_aware: + self.category_to_sequence_names = self.dataset.category_to_sequence_names() + self.categories = list(self.category_to_sequence_names.keys()) + def __len__(self) -> int: return self.num_batches @@ -67,7 +81,25 @@ class SceneBatchSampler(Sampler[List[int]]): def _sample_batch(self, batch_idx) -> List[int]: n_per_seq = np.random.choice(self.images_per_seq_options) n_seqs = -(-self.batch_size // n_per_seq) # round up - chosen_seq = _capped_random_choice(self.seq_names, n_seqs, replace=False) + + if self.category_aware: + # first sample categories at random, these can be repeated in the batch + chosen_cat = _capped_random_choice(self.categories, n_seqs, replace=True) + # then randomly sample a set of unique sequences within each category + chosen_seq = [] + for cat, n_per_category in Counter(chosen_cat).items(): + category_chosen_seq = _capped_random_choice( + self.category_to_sequence_names[cat], + n_per_category, + replace=False, + ) + chosen_seq.extend([str(s) for s in category_chosen_seq]) + else: + chosen_seq = _capped_random_choice( + self.seq_names, + n_seqs, + replace=False, + ) if self.sample_consecutive_frames: frame_idx = [] diff --git a/tests/implicitron/data/data_source.yaml 
b/tests/implicitron/data/data_source.yaml index bd6f4aad..cb10134e 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args: test_on_train: false only_test_set: false load_eval_batches: true + num_load_workers: 4 n_known_frames_for_test: 0 dataset_class_type: JsonIndexDataset path_manager_factory_class_type: PathManagerFactory diff --git a/tests/implicitron/test_batch_sampler.py b/tests/implicitron/test_batch_sampler.py index 50d0b8fe..aba57551 100644 --- a/tests/implicitron/test_batch_sampler.py +++ b/tests/implicitron/test_batch_sampler.py @@ -11,17 +11,20 @@ from dataclasses import dataclass from itertools import product import numpy as np + +import torch from pytorch3d.implicitron.dataset.data_loader_map_provider import ( DoublePoolBatchSampler, ) -from pytorch3d.implicitron.dataset.dataset_base import DatasetBase +from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.scene_batch_sampler import SceneBatchSampler @dataclass class MockFrameAnnotation: frame_number: int + sequence_name: str = "sequence" frame_timestamp: float = 0.0 @@ -41,6 +44,9 @@ class MockDataset(DatasetBase): self.frame_annots = [ {"frame_annotation": MockFrameAnnotation(no)} for no in frame_nos * num_seq ] + for seq_name, idx in self._seq_to_idx.items(): + for i in idx: + self.frame_annots[i]["frame_annotation"].sequence_name = seq_name def get_frame_numbers_and_timestamps(self, idxs): out = [] @@ -51,6 +57,16 @@ class MockDataset(DatasetBase): ) return out + def __getitem__(self, index: int): + fa = self.frame_annots[index]["frame_annotation"] + fd = FrameData( + sequence_name=fa.sequence_name, + sequence_category="default_category", + frame_number=torch.LongTensor([fa.frame_number]), + frame_timestamp=torch.LongTensor([fa.frame_timestamp]), + ) + return fd + class TestSceneBatchSampler(unittest.TestCase): def 
setUp(self): diff --git a/tests/implicitron/test_json_index_dataset_provider_v2.py b/tests/implicitron/test_json_index_dataset_provider_v2.py index 04ea2eaa..3191c0ee 100644 --- a/tests/implicitron/test_json_index_dataset_provider_v2.py +++ b/tests/implicitron/test_json_index_dataset_provider_v2.py @@ -41,22 +41,73 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase): categories = ["A", "B"] subset_name = "test" eval_batch_size = 5 + n_frames = 8 * 3 + n_sequences = 5 + n_eval_batches = 10 with tempfile.TemporaryDirectory() as tmpd: _make_random_json_dataset_map_provider_v2_data( tmpd, categories, eval_batch_size=eval_batch_size, + n_frames=n_frames, + n_sequences=n_sequences, + n_eval_batches=n_eval_batches, ) for n_known_frames_for_test in [0, 2]: - for category in categories: - dataset_provider = JsonIndexDatasetMapProviderV2( + dataset_providers = { + category: JsonIndexDatasetMapProviderV2( category=category, subset_name="test", dataset_root=tmpd, n_known_frames_for_test=n_known_frames_for_test, ) + for category in [*categories, ",".join(sorted(categories))] + } + for category, dataset_provider in dataset_providers.items(): dataset_map = dataset_provider.get_dataset_map() for set_ in ["train", "val", "test"]: + dataset = getattr(dataset_map, set_) + + cat2seq = dataset.category_to_sequence_names() + self.assertEqual(",".join(sorted(cat2seq.keys())), category) + + if not (n_known_frames_for_test != 0 and set_ == "test"): + # check the lengths only in case we do not have the + # n_known_frames_for_test set + expected_dataset_len = n_frames * n_sequences // 3 + if "," in category: + # multicategory json index dataset, sum the lengths of + # category-specific ones + expected_dataset_len = sum( + len( + getattr( + dataset_providers[c].get_dataset_map(), set_ + ) + ) + for c in categories + ) + self.assertEqual( + sum(len(s) for s in cat2seq.values()), + n_sequences * len(categories), + ) + self.assertEqual(len(cat2seq), len(categories)) + else: + 
self.assertEqual( + len(cat2seq[category]), + n_sequences, + ) + self.assertEqual(len(cat2seq), 1) + self.assertEqual(len(dataset), expected_dataset_len) + + if set_ == "test": + # check the number of eval batches + expected_n_eval_batches = n_eval_batches + if "," in category: + expected_n_eval_batches *= len(categories) + self.assertTrue( + len(dataset.get_eval_batches()) + == expected_n_eval_batches + ) if set_ in ["train", "val"]: dataloader = torch.utils.data.DataLoader( getattr(dataset_map, set_), @@ -80,6 +131,7 @@ class TestJsonIndexDatasetProviderV2(unittest.TestCase): dataset_provider.get_category_to_subset_name_list() ) category_to_subset_list_ = {c: [subset_name] for c in categories} + self.assertTrue(category_to_subset_list == category_to_subset_list_) @@ -88,6 +140,7 @@ def _make_random_json_dataset_map_provider_v2_data( categories: List[str], n_frames: int = 8, n_sequences: int = 5, + n_eval_batches: int = 10, H: int = 50, W: int = 30, subset_name: str = "test", @@ -100,7 +153,7 @@ def _make_random_json_dataset_map_provider_v2_data( sequence_annotations = [] frame_index = [] for seq_i in range(n_sequences): - seq_name = str(seq_i) + seq_name = category + str(seq_i) for i in range(n_frames): # generate and store image imdir = os.path.join(root, category, seq_name, "images") @@ -165,7 +218,8 @@ def _make_random_json_dataset_map_provider_v2_data( json.dump(set_list, f) eval_batches = [ - random.sample(test_frame_index, eval_batch_size) for _ in range(10) + random.sample(test_frame_index, eval_batch_size) + for _ in range(n_eval_batches) ] eval_b_dir = os.path.join(root, category, "eval_batches")