From 0f12c51646a5754122c6375103f06ac8ee8fca7d Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Fri, 20 May 2022 07:50:30 -0700 Subject: [PATCH] data_loader_map_provider Summary: replace dataloader_zoo with a pluggable DataLoaderMapProvider. Reviewed By: shapovalov Differential Revision: D36475441 fbshipit-source-id: d16abb190d876940434329928f2e3f2794a25416 --- projects/implicitron_trainer/README.md | 2 +- .../configs/repro_base.yaml | 2 +- .../configs/repro_multiseq_base.yaml | 2 +- .../configs/repro_singleseq_base.yaml | 2 +- .../configs/repro_singleseq_wce_base.yaml | 2 +- projects/implicitron_trainer/experiment.py | 6 +- .../dataset/data_loader_map_provider.py | 139 ++++++++++++++++++ pytorch3d/implicitron/dataset/data_source.py | 19 +-- .../implicitron/dataset/dataloader_zoo.py | 116 --------------- tests/implicitron/data/data_source.yaml | 21 +-- 10 files changed, 166 insertions(+), 145 deletions(-) create mode 100644 pytorch3d/implicitron/dataset/data_loader_map_provider.py delete mode 100644 pytorch3d/implicitron/dataset/dataloader_zoo.py diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md index 80253a20..5106f72c 100644 --- a/projects/implicitron_trainer/README.md +++ b/projects/implicitron_trainer/README.md @@ -237,7 +237,7 @@ generic_model_args: GenericModel solver_args: init_optimizer data_source_args: ImplicitronDataSource └-- dataset_map_provider_*_args -└-- dataloader_args +└-- data_loader_map_provider_*_args ``` Please look at the annotations of the respective classes or functions for the lists of hyperparameters. diff --git a/projects/implicitron_trainer/configs/repro_base.yaml b/projects/implicitron_trainer/configs/repro_base.yaml index 21fc39c5..041ba47b 100644 --- a/projects/implicitron_trainer/configs/repro_base.yaml +++ b/projects/implicitron_trainer/configs/repro_base.yaml @@ -7,7 +7,7 @@ visualize_interval: 0 visdom_port: 8097 data_source_args: dataset_provider_class_type: JsonIndexDatasetMapProvider - dataloader_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 dataset_len: 1000 dataset_len_val: 1 diff --git a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml index ae3eae32..a659a52b 100644 --- a/projects/implicitron_trainer/configs/repro_multiseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_multiseq_base.yaml @@ -2,7 +2,7 @@ defaults: - repro_base.yaml - _self_ data_source_args: - dataloader_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 10 dataset_len: 1000 dataset_len_val: 1 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml index 1419b7e0..177e6613 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_base.yaml @@ -2,7 +2,7 @@ defaults: - repro_base - _self_ data_source_args: - dataloader_args: + data_loader_map_provider_SequenceDataLoaderMapProvider_args: batch_size: 1 dataset_len: 1000 dataset_len_val: 1 diff --git a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml index 57de6cf4..b714490b 100644 --- a/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml +++ b/projects/implicitron_trainer/configs/repro_singleseq_wce_base.yaml @@ -2,7 +2,7 @@ defaults: - repro_singleseq_base - _self_ 
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
diff --git a/projects/implicitron_trainer/experiment.py b/projects/implicitron_trainer/experiment.py
index db6e0591..0ce35e5a 100755
--- a/projects/implicitron_trainer/experiment.py
+++ b/projects/implicitron_trainer/experiment.py
@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
+from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
 from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
@@ -553,7 +553,7 @@ def _eval_and_dump(
     cfg,
     task: Task,
     datasets: DatasetMap,
-    dataloaders: Dataloaders,
+    dataloaders: DataLoaderMap,
     model,
     stats,
     device,
@@ -566,7 +566,7 @@

     dataloader = dataloaders.test
     if dataloader is None:
-        raise ValueError('Dataloaders have to contain the "test" entry for eval!')
+        raise ValueError('The DataLoaderMap must contain the "test" entry for eval!')

     if task == Task.SINGLE_SEQUENCE:
         if datasets.train is None:
diff --git a/pytorch3d/implicitron/dataset/data_loader_map_provider.py b/pytorch3d/implicitron/dataset/data_loader_map_provider.py
new file mode 100644
index 00000000..d22a4678
--- /dev/null
+++ b/pytorch3d/implicitron/dataset/data_loader_map_provider.py
@@ -0,0 +1,139 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Optional, Sequence
+
+import torch
+from pytorch3d.implicitron.tools.config import registry, ReplaceableBase
+
+from .dataset_base import FrameData, ImplicitronDatasetBase
+from .dataset_map_provider import DatasetMap
+from .scene_batch_sampler import SceneBatchSampler
+
+
+@dataclass
+class DataLoaderMap:
+    """
+    A collection of data loaders for Implicitron.
+
+    Members:
+
+        train: a data loader for training
+        val: a data loader for validating during training
+        test: a data loader for final evaluation
+    """
+
+    train: Optional[torch.utils.data.DataLoader[FrameData]]
+    val: Optional[torch.utils.data.DataLoader[FrameData]]
+    test: Optional[torch.utils.data.DataLoader[FrameData]]
+
+    def __getitem__(
+        self, split: str
+    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
+        """
+        Get one of the data loaders by key (name of data split).
+        """
+        if split not in ["train", "val", "test"]:
+            raise ValueError(f"{split} is not a valid split name (train/val/test)")
+        return getattr(self, split)
+
+
+class DataLoaderMapProviderBase(ReplaceableBase):
+    """
+    Provider of a collection of data loaders for a given collection of datasets.
+    """
+
+    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
+        """
+        Returns a collection of data loaders for a given collection of datasets.
+        """
+        raise NotImplementedError()
+
+
+@registry.register
+class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
+    """
+    The default implementation of DataLoaderMapProviderBase.
+
+    Members:
+        batch_size: The size of the batch of the data loader.
+        num_workers: The number of data-loading threads.
+        dataset_len: The number of batches in a training epoch.
+        dataset_len_val: The number of batches in a validation epoch.
+        images_per_seq_options: Possible numbers of images sampled per sequence.
+        sample_consecutive_frames: if True, will sample a contiguous interval of frames
+            in the sequence. It first sorts the frames by timestamps when available,
+            otherwise by frame numbers, finds the connected segments within the sequence
+            of sufficient length, then samples a random pivot element among them and
+            ideally uses it as the middle of the temporal window, shifting the borders
+            where necessary. This strategy mitigates the bias against shorter segments
+            and their boundaries.
+        consecutive_frames_max_gap: if a number > 0, it defines the maximum
+            difference in frame_number of neighbouring frames when forming connected
+            segments; if both this and consecutive_frames_max_gap_seconds are 0,
+            the whole sequence is considered a segment regardless of frame numbers.
+        consecutive_frames_max_gap_seconds: if a number > 0.0, it defines the
+            maximum difference in frame_timestamp of neighbouring frames when forming
+            connected segments; if both this and consecutive_frames_max_gap are 0,
+            the whole sequence is considered a segment regardless of frame timestamps.
+    """
+
+    batch_size: int = 1
+    num_workers: int = 0
+    dataset_len: int = 1000
+    dataset_len_val: int = 1
+    images_per_seq_options: Sequence[int] = (2,)
+    sample_consecutive_frames: bool = False
+    consecutive_frames_max_gap: int = 0
+    consecutive_frames_max_gap_seconds: float = 0.1
+
+    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
+        """
+        Returns a collection of data loaders for a given collection of datasets.
+ """ + + data_loader_kwargs = { + "num_workers": self.num_workers, + "collate_fn": FrameData.collate, + } + + def train_or_val_loader( + dataset: Optional[ImplicitronDatasetBase], num_batches: int + ) -> Optional[torch.utils.data.DataLoader]: + if dataset is None: + return None + batch_sampler = SceneBatchSampler( + dataset, + self.batch_size, + num_batches=len(dataset) if num_batches <= 0 else num_batches, + images_per_seq_options=self.images_per_seq_options, + sample_consecutive_frames=self.sample_consecutive_frames, + consecutive_frames_max_gap=self.consecutive_frames_max_gap, + consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds, + ) + return torch.utils.data.DataLoader( + dataset, + batch_sampler=batch_sampler, + **data_loader_kwargs, + ) + + train_data_loader = train_or_val_loader(datasets.train, self.dataset_len) + val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val) + + test_dataset = datasets.test + if test_dataset is not None: + test_data_loader = torch.utils.data.DataLoader( + test_dataset, + batch_sampler=test_dataset.get_eval_batches(), + **data_loader_kwargs, + ) + else: + test_data_loader = None + + return DataLoaderMap( + train=train_data_loader, val=val_data_loader, test=test_data_loader + ) diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index 8723e089..dc2e7054 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -6,15 +6,10 @@ from typing import Tuple -from omegaconf import DictConfig -from pytorch3d.implicitron.tools.config import ( - get_default_args_field, - ReplaceableBase, - run_auto_creation, -) +from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation from . import json_index_dataset_map_provider # noqa -from .dataloader_zoo import dataloader_zoo, Dataloaders +from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task @@ -24,7 +19,7 @@ class DataSourceBase(ReplaceableBase): and DataLoader configuration. """ - def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]: + def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: raise NotImplementedError() @@ -36,18 +31,20 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] Members: dataset_map_provider_class_type: identifies type for dataset_map_provider. e.g. JsonIndexDatasetMapProvider for Co3D. + data_loader_map_provider_class_type: identifies type for data_loader_map_provider. 
""" dataset_map_provider: DatasetMapProviderBase dataset_map_provider_class_type: str - dataloader_args: DictConfig = get_default_args_field(dataloader_zoo) + data_loader_map_provider: DataLoaderMapProviderBase + data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider" def __post_init__(self): run_auto_creation(self) - def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]: + def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]: datasets = self.dataset_map_provider.get_dataset_map() - dataloaders = dataloader_zoo(datasets, **self.dataloader_args) + dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets) return datasets, dataloaders def get_task(self) -> Task: diff --git a/pytorch3d/implicitron/dataset/dataloader_zoo.py b/pytorch3d/implicitron/dataset/dataloader_zoo.py deleted file mode 100644 index 58cb576e..00000000 --- a/pytorch3d/implicitron/dataset/dataloader_zoo.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from dataclasses import dataclass -from typing import Optional, Sequence - -import torch -from pytorch3d.implicitron.tools.config import enable_get_default_args - -from .dataset_base import FrameData, ImplicitronDatasetBase -from .dataset_map_provider import DatasetMap -from .scene_batch_sampler import SceneBatchSampler - - -@dataclass -class Dataloaders: - """ - A provider of dataloaders for implicitron. - - Members: - - train: a dataloader for training - val: a dataloader for validating during training - test: a dataloader for final evaluation - """ - - train: Optional[torch.utils.data.DataLoader[FrameData]] - val: Optional[torch.utils.data.DataLoader[FrameData]] - test: Optional[torch.utils.data.DataLoader[FrameData]] - - -def dataloader_zoo( - datasets: DatasetMap, - batch_size: int = 1, - num_workers: int = 0, - dataset_len: int = 1000, - dataset_len_val: int = 1, - images_per_seq_options: Sequence[int] = (2,), - sample_consecutive_frames: bool = False, - consecutive_frames_max_gap: int = 0, - consecutive_frames_max_gap_seconds: float = 0.1, -) -> Dataloaders: - """ - Returns a set of dataloaders for a given set of datasets. - - Args: - datasets: A dictionary containing the - `"dataset_subset_name": torch_dataset_object` key, value pairs. - batch_size: The size of the batch of the dataloader. - num_workers: Number data-loading threads. - dataset_len: The number of batches in a training epoch. - dataset_len_val: The number of batches in a validation epoch. - images_per_seq_options: Possible numbers of images sampled per sequence. - sample_consecutive_frames: if True, will sample a contiguous interval of frames - in the sequence. It first sorts the frames by timestimps when available, - otherwise by frame numbers, finds the connected segments within the sequence - of sufficient length, then samples a random pivot element among them and - ideally uses it as a middle of the temporal window, shifting the borders - where necessary. This strategy mitigates the bias against shorter segments - and their boundaries. - consecutive_frames_max_gap: if a number > 0, then used to define the maximum - difference in frame_number of neighbouring frames when forming connected - segments; if both this and consecutive_frames_max_gap_seconds are 0s, - the whole sequence is considered a segment regardless of frame numbers. 
- consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the - maximum difference in frame_timestamp of neighbouring frames when forming - connected segments; if both this and consecutive_frames_max_gap are 0s, - the whole sequence is considered a segment regardless of frame timestamps. - - Returns: - dataloaders: A dictionary containing the - `"dataset_subset_name": torch_dataloader_object` key, value pairs. - """ - - dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate} - - def train_or_val_loader( - dataset: Optional[ImplicitronDatasetBase], num_batches: int - ) -> Optional[torch.utils.data.DataLoader]: - if dataset is None: - return None - batch_sampler = SceneBatchSampler( - dataset, - batch_size, - num_batches=len(dataset) if num_batches <= 0 else num_batches, - images_per_seq_options=images_per_seq_options, - sample_consecutive_frames=sample_consecutive_frames, - consecutive_frames_max_gap=consecutive_frames_max_gap, - consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds, - ) - return torch.utils.data.DataLoader( - dataset, - batch_sampler=batch_sampler, - **dataloader_kwargs, - ) - - train_dataloader = train_or_val_loader(datasets.train, dataset_len) - val_dataloader = train_or_val_loader(datasets.val, dataset_len_val) - - test_dataset = datasets.test - if test_dataset is not None: - test_dataloader = torch.utils.data.DataLoader( - test_dataset, - batch_sampler=test_dataset.get_eval_batches(), - **dataloader_kwargs, - ) - else: - test_dataloader = None - - return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader) - - -enable_get_default_args(dataloader_zoo) diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index bff7ae08..e71dac1f 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -1,14 +1,5 @@ dataset_map_provider_class_type: ??? -dataloader_args: - batch_size: 1 - num_workers: 0 - dataset_len: 1000 - dataset_len_val: 1 - images_per_seq_options: - - 2 - sample_consecutive_frames: false - consecutive_frames_max_gap: 0 - consecutive_frames_max_gap_seconds: 0.1 +data_loader_map_provider_class_type: SequenceDataLoaderMapProvider dataset_map_provider_JsonIndexDatasetMapProvider_args: category: ??? task_str: singlesequence @@ -31,3 +22,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args: image_height: 800 remove_empty_masks: true path_manager: null +data_loader_map_provider_SequenceDataLoaderMapProvider_args: + batch_size: 1 + num_workers: 0 + dataset_len: 1000 + dataset_len_val: 1 + images_per_seq_options: + - 2 + sample_consecutive_frames: false + consecutive_frames_max_gap: 0 + consecutive_frames_max_gap_seconds: 0.1
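
Usage sketch (an editorial illustration, not part of this commit): the point of making the
provider "pluggable" is that any subclass of DataLoaderMapProviderBase decorated with
registry.register becomes selectable via data_loader_map_provider_class_type. The class name
SimpleDataLoaderMapProvider and its single batch_size field below are hypothetical; the base
class, DataLoaderMap, DatasetMap, ImplicitronDatasetBase, and FrameData.collate all come from
the code added in this patch.

from typing import Optional

import torch
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    DataLoaderMapProviderBase,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData, ImplicitronDatasetBase
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.tools.config import registry


@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    # Hypothetical provider: plain loaders with no scene-aware batch sampling.
    batch_size: int = 1

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        # Wrap each available split in an ordinary DataLoader, reusing the
        # FrameData collation that the default provider uses.
        def loader(
            dataset: Optional[ImplicitronDatasetBase],
        ) -> Optional[torch.utils.data.DataLoader]:
            if dataset is None:
                return None
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                collate_fn=FrameData.collate,
            )

        return DataLoaderMap(
            train=loader(datasets.train),
            val=loader(datasets.val),
            test=loader(datasets.test),
        )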
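
A config would then select such a provider following the same naming pattern as the yaml files
changed above, with its arguments under a data_loader_map_provider_<ClassName>_args key
(hypothetical values):

data_source_args:
  data_loader_map_provider_class_type: SimpleDataLoaderMapProvider
  data_loader_map_provider_SimpleDataLoaderMapProvider_args:
    batch_size: 10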