diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index 5f0d62e2..f810426b 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -117,6 +117,12 @@ data_source_ImplicitronDataSource_args:
     sample_consecutive_frames: false
     consecutive_frames_max_gap: 0
     consecutive_frames_max_gap_seconds: 0.1
+  data_loader_map_provider_SimpleDataLoaderMapProvider_args:
+    batch_size: 1
+    num_workers: 0
+    dataset_length_train: 0
+    dataset_length_val: 0
+    dataset_length_test: 0
 model_factory_ImplicitronModelFactory_args:
   force_load: false
   model_class_type: GenericModel
diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py
index d1c4d4e4..2a0de2ec 100644
--- a/pytorch3d/implicitron/dataset/data_loader_map_provider.py
+++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py
@@ -61,6 +61,106 @@ class DataLoaderMapProviderBase(ReplaceableBase):
         raise NotImplementedError()
 
 
+@registry.register
+class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
+    """
+    Trivial implementation of DataLoaderMapProviderBase.
+
+    If a dataset returns batches from get_eval_batches(), then
+    they will be what the corresponding dataloader returns,
+    independently of any of the fields on this class.
+
+    Otherwise, returns shuffled batches.
+
+    Members:
+        batch_size: The size of each randomly sampled batch.
+        num_workers: Passed to every DataLoader as num_workers.
+        dataset_length_train: If positive, the number of random batches per
+            epoch of the train dataloader. Otherwise an epoch is a single
+            shuffled pass over the train dataset.
+        dataset_length_val: As dataset_length_train, for the val dataloader.
+        dataset_length_test: As dataset_length_train, for the test dataloader.
+    """
+
+    batch_size: int = 1
+    num_workers: int = 0
+    dataset_length_train: int = 0
+    dataset_length_val: int = 0
+    dataset_length_test: int = 0
+
+    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
+        """
+        Returns a collection of data loaders for a given collection of datasets.
+        """
+        return DataLoaderMap(
+            train=self._make_data_loader(
+                datasets.train,
+                self.dataset_length_train,
+            ),
+            val=self._make_data_loader(
+                datasets.val,
+                self.dataset_length_val,
+            ),
+            test=self._make_data_loader(
+                datasets.test,
+                self.dataset_length_test,
+            ),
+        )
+
+    def _make_data_loader(
+        self,
+        dataset: Optional[DatasetBase],
+        num_batches: int,
+    ) -> Optional[DataLoader[FrameData]]:
+        """
+        Returns the dataloader for a dataset.
+
+        Args:
+            dataset: the dataset, or None if this split is absent
+            num_batches: if positive, the number of random batches per epoch;
+                otherwise an epoch is one shuffled pass over the dataset.
+                Ignored if the dataset provides its own eval batches.
+
+        Returns:
+            The dataloader, or None if dataset is None.
+        """
+        if dataset is None:
+            return None
+
+        data_loader_kwargs = {
+            "num_workers": self.num_workers,
+            "collate_fn": dataset.frame_data_type.collate,
+        }
+
+        eval_batches = dataset.get_eval_batches()
+        if eval_batches is not None:
+            return DataLoader(
+                dataset,
+                batch_sampler=eval_batches,
+                **data_loader_kwargs,
+            )
+
+        if num_batches > 0:
+            num_samples = self.batch_size * num_batches
+        else:
+            num_samples = None
+
+        # Sample with replacement only when a custom number of samples is
+        # requested: without one, an epoch should be a true shuffle of the
+        # dataset rather than a draw with repeated and missing elements.
+        sampler = RandomSampler(
+            dataset,
+            replacement=num_samples is not None,
+            num_samples=num_samples,
+        )
+        batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
+        return DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            **data_loader_kwargs,
+        )
+
+
 class DoublePoolBatchSampler(Sampler[List[int]]):
     """
     Batch sampler for making random batches of a single frame
diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml
index 9528fdfa..3c819817 100644
--- a/tests/implicitron/data/data_source.yaml
+++ b/tests/implicitron/data/data_source.yaml
@@ -105,3 +105,9 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
   sample_consecutive_frames: false
   consecutive_frames_max_gap: 0
   consecutive_frames_max_gap_seconds: 0.1
+data_loader_map_provider_SimpleDataLoaderMapProvider_args:
+  batch_size: 1
+  num_workers: 0
+  dataset_length_train: 0
+  dataset_length_val: 0
+  dataset_length_test: 0
diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py
index cf5033b9..950e0c10 100644
--- a/tests/implicitron/test_data_source.py
+++ b/tests/implicitron/test_data_source.py
@@ -10,6 +10,10 @@ import unittest.mock
 
 import torch
 from omegaconf import OmegaConf
+from pytorch3d.implicitron.dataset.data_loader_map_provider import (
+    SequenceDataLoaderMapProvider,
+    SimpleDataLoaderMapProvider,
+)
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
 from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
 from pytorch3d.implicitron.tools.config import get_default_args
@@ -64,7 +68,6 @@ class TestDataSource(unittest.TestCase):
             return
         args = get_default_args(ImplicitronDataSource)
         args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
-        args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
         dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
         dataset_args.category = "skateboard"
         dataset_args.test_restrict_sequence_id = 0
@@ -73,8 +76,35 @@ class TestDataSource(unittest.TestCase):
         dataset_args.dataset_root = "manifold://co3d/tree/extracted"
 
         data_source = ImplicitronDataSource(**args)
+        self.assertIsInstance(
+            data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
+        )
         _, data_loaders = data_source.get_datasets_and_dataloaders()
         self.assertEqual(len(data_loaders.train), 81)
         for i in data_loaders.train:
             self.assertEqual(i.frame_type, ["test_known"])
             break
+
+    def test_simple(self):
+        if os.environ.get("INSIDE_RE_WORKER") is not None:
+            return
+        args = get_default_args(ImplicitronDataSource)
+        args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
+        args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"
+        dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
+        dataset_args.category = "skateboard"
+        dataset_args.test_restrict_sequence_id = 0
+        dataset_args.n_frames_per_sequence = -1
+
+        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
+
+        data_source = ImplicitronDataSource(**args)
+        self.assertIsInstance(
+            data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
+        )
+        _, data_loaders = data_source.get_datasets_and_dataloaders()
+
+        self.assertEqual(len(data_loaders.train), 81)
+        for i in data_loaders.train:
+            self.assertEqual(i.frame_type, ["test_known"])
+            break