data_loader_map_provider

Summary: replace dataloader_zoo with a pluggable DataLoaderMapProvider.

Reviewed By: shapovalov

Differential Revision: D36475441

fbshipit-source-id: d16abb190d876940434329928f2e3f2794a25416
Authored by Jeremy Reizenstein on 2022-05-20 07:50:30 -07:00; committed by Facebook GitHub Bot.
parent 79c61a2d86
commit 0f12c51646
10 changed files with 166 additions and 145 deletions


@@ -237,7 +237,7 @@ generic_model_args: GenericModel
 solver_args: init_optimizer
 data_source_args: ImplicitronDataSource
     └-- dataset_map_provider_*_args
-    └-- dataloader_args
+    └-- data_loader_map_provider_*_args
```
Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
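
For example, a hedged sketch of the new override layout in an experiment config (the key names come from the diffs below; the batch_size/num_workers values are illustrative only):

```yaml
data_source_args:
  data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
    batch_size: 10
    num_workers: 4
```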


@@ -7,7 +7,7 @@ visualize_interval: 0
 visdom_port: 8097
 data_source_args:
   dataset_provider_class_type: JsonIndexDatasetMapProvider
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1


@@ -2,7 +2,7 @@ defaults:
 - repro_base.yaml
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1


@@ -2,7 +2,7 @@ defaults:
 - repro_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     dataset_len: 1000
     dataset_len_val: 1


@@ -2,7 +2,7 @@ defaults:
 - repro_singleseq_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1


@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
+from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
 from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
@@ -553,7 +553,7 @@ def _eval_and_dump(
     cfg,
     task: Task,
     datasets: DatasetMap,
-    dataloaders: Dataloaders,
+    dataloaders: DataLoaderMap,
     model,
     stats,
     device,
@@ -566,7 +566,7 @@
     dataloader = dataloaders.test
     if dataloader is None:
-        raise ValueError('Dataloaders have to contain the "test" entry for eval!')
+        raise ValueError('DataLoaderMap has to contain the "test" entry for eval!')
     if task == Task.SINGLE_SEQUENCE:
         if datasets.train is None:


@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch

from pytorch3d.implicitron.tools.config import registry, ReplaceableBase

from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler


@dataclass
class DataLoaderMap:
    """
    A collection of data loaders for Implicitron.

    Members:
        train: a data loader for training
        val: a data loader for validating during training
        test: a data loader for final evaluation
    """

    train: Optional[torch.utils.data.DataLoader[FrameData]]
    val: Optional[torch.utils.data.DataLoader[FrameData]]
    test: Optional[torch.utils.data.DataLoader[FrameData]]

    def __getitem__(
        self, split: str
    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
        """
        Get one of the data loaders by key (name of data split)
        """
        if split not in ["train", "val", "test"]:
            raise ValueError(f"{split} was not a valid split name (train/val/test)")
        return getattr(self, split)


class DataLoaderMapProviderBase(ReplaceableBase):
    """
    Provider of a collection of data loaders for a given collection of datasets.
    """

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.
        """
        raise NotImplementedError()


@registry.register
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
    """
    The default implementation of DataLoaderMapProviderBase.

    Members:
        batch_size: The size of the batch of the data loader.
        num_workers: Number of data-loading threads.
        dataset_len: The number of batches in a training epoch.
        dataset_len_val: The number of batches in a validation epoch.
        images_per_seq_options: Possible numbers of images sampled per sequence.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as the middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.
    """

    batch_size: int = 1
    num_workers: int = 0
    dataset_len: int = 1000
    dataset_len_val: int = 1
    images_per_seq_options: Sequence[int] = (2,)
    sample_consecutive_frames: bool = False
    consecutive_frames_max_gap: int = 0
    consecutive_frames_max_gap_seconds: float = 0.1

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.
        """
        data_loader_kwargs = {
            "num_workers": self.num_workers,
            "collate_fn": FrameData.collate,
        }

        def train_or_val_loader(
            dataset: Optional[ImplicitronDatasetBase], num_batches: int
        ) -> Optional[torch.utils.data.DataLoader]:
            if dataset is None:
                return None
            batch_sampler = SceneBatchSampler(
                dataset,
                self.batch_size,
                num_batches=len(dataset) if num_batches <= 0 else num_batches,
                images_per_seq_options=self.images_per_seq_options,
                sample_consecutive_frames=self.sample_consecutive_frames,
                consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
            )
            return torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                **data_loader_kwargs,
            )

        train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
        val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)

        test_dataset = datasets.test
        if test_dataset is not None:
            test_data_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_sampler=test_dataset.get_eval_batches(),
                **data_loader_kwargs,
            )
        else:
            test_data_loader = None

        return DataLoaderMap(
            train=train_data_loader, val=val_data_loader, test=test_data_loader
        )
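
The point of the refactor is that the provider is now a replaceable: a custom implementation can be registered and selected by name. A hedged sketch, not part of this commit; `SimpleLoaderMapProvider` and its `batch_size` field are hypothetical, while the imported names come from the new module above:

```python
# A hedged sketch, not from this commit: SimpleLoaderMapProvider is
# hypothetical; the imports are the names introduced in this diff.
import torch

from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    DataLoaderMapProviderBase,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap


@registry.register
class SimpleLoaderMapProvider(DataLoaderMapProviderBase):
    """Hypothetical provider: plain loaders without scene-aware batch sampling."""

    batch_size: int = 4

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        def make_loader(dataset):
            if dataset is None:
                return None
            # FrameData needs its custom collate function to batch correctly.
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                collate_fn=FrameData.collate,
            )

        return DataLoaderMap(
            train=make_loader(datasets.train),
            val=make_loader(datasets.val),
            test=make_loader(datasets.test),
        )
```

Selecting it would then be a matter of setting `data_loader_map_provider_class_type: SimpleLoaderMapProvider` in the config, with its own `data_loader_map_provider_SimpleLoaderMapProvider_args` block.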


@@ -6,15 +6,10 @@
 from typing import Tuple

-from omegaconf import DictConfig
-from pytorch3d.implicitron.tools.config import (
-    get_default_args_field,
-    ReplaceableBase,
-    run_auto_creation,
-)
+from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation

 from . import json_index_dataset_map_provider  # noqa
-from .dataloader_zoo import dataloader_zoo, Dataloaders
+from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task
@@ -24,7 +19,7 @@ class DataSourceBase(ReplaceableBase):
     and DataLoader configuration.
     """

-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         raise NotImplementedError()
@@ -36,18 +31,20 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     Members:
         dataset_map_provider_class_type: identifies type for dataset_map_provider.
             e.g. JsonIndexDatasetMapProvider for Co3D.
+        data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
     """

     dataset_map_provider: DatasetMapProviderBase
     dataset_map_provider_class_type: str
-    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+    data_loader_map_provider: DataLoaderMapProviderBase
+    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"

     def __post_init__(self):
         run_auto_creation(self)

-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         datasets = self.dataset_map_provider.get_dataset_map()
-        dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
+        dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
         return datasets, dataloaders

     def get_task(self) -> Task:
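
With auto-creation, callers no longer touch dataloader_zoo directly. A minimal usage sketch, assuming a `cfg` whose `data_source_args` match the defaults at the bottom of this diff (not code from this commit):

```python
# Hedged sketch: `cfg` is assumed to be an OmegaConf config prepared by the
# trainer's config machinery (expand_args_fields etc.), which is not shown here.
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource

data_source = ImplicitronDataSource(**cfg.data_source_args)
datasets, dataloaders = data_source.get_datasets_and_dataloaders()

train_loader = dataloaders.train             # attribute access,
assert train_loader is dataloaders["train"]  # or DataLoaderMap.__getitem__
```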


@@ -1,116 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch

from pytorch3d.implicitron.tools.config import enable_get_default_args

from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler


@dataclass
class Dataloaders:
    """
    A provider of dataloaders for implicitron.

    Members:
        train: a dataloader for training
        val: a dataloader for validating during training
        test: a dataloader for final evaluation
    """

    train: Optional[torch.utils.data.DataLoader[FrameData]]
    val: Optional[torch.utils.data.DataLoader[FrameData]]
    test: Optional[torch.utils.data.DataLoader[FrameData]]


def dataloader_zoo(
    datasets: DatasetMap,
    batch_size: int = 1,
    num_workers: int = 0,
    dataset_len: int = 1000,
    dataset_len_val: int = 1,
    images_per_seq_options: Sequence[int] = (2,),
    sample_consecutive_frames: bool = False,
    consecutive_frames_max_gap: int = 0,
    consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dataloaders:
    """
    Returns a set of dataloaders for a given set of datasets.

    Args:
        datasets: A dictionary containing the
            `"dataset_subset_name": torch_dataset_object` key, value pairs.
        batch_size: The size of the batch of the dataloader.
        num_workers: Number of data-loading threads.
        dataset_len: The number of batches in a training epoch.
        dataset_len_val: The number of batches in a validation epoch.
        images_per_seq_options: Possible numbers of images sampled per sequence.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as the middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.

    Returns:
        dataloaders: A dictionary containing the
            `"dataset_subset_name": torch_dataloader_object` key, value pairs.
    """
    dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}

    def train_or_val_loader(
        dataset: Optional[ImplicitronDatasetBase], num_batches: int
    ) -> Optional[torch.utils.data.DataLoader]:
        if dataset is None:
            return None
        batch_sampler = SceneBatchSampler(
            dataset,
            batch_size,
            num_batches=len(dataset) if num_batches <= 0 else num_batches,
            images_per_seq_options=images_per_seq_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
            consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            **dataloader_kwargs,
        )

    train_dataloader = train_or_val_loader(datasets.train, dataset_len)
    val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)

    test_dataset = datasets.test
    if test_dataset is not None:
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_sampler=test_dataset.get_eval_batches(),
            **dataloader_kwargs,
        )
    else:
        test_dataloader = None

    return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)


enable_get_default_args(dataloader_zoo)


@@ -1,14 +1,5 @@
 dataset_map_provider_class_type: ???
-dataloader_args:
-  batch_size: 1
-  num_workers: 0
-  dataset_len: 1000
-  dataset_len_val: 1
-  images_per_seq_options:
-  - 2
-  sample_consecutive_frames: false
-  consecutive_frames_max_gap: 0
-  consecutive_frames_max_gap_seconds: 0.1
+data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
 dataset_map_provider_JsonIndexDatasetMapProvider_args:
   category: ???
   task_str: singlesequence
@@ -31,3 +22,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
   image_height: 800
   remove_empty_masks: true
   path_manager: null
+data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+  batch_size: 1
+  num_workers: 0
+  dataset_len: 1000
+  dataset_len_val: 1
+  images_per_seq_options:
+  - 2
+  sample_consecutive_frames: false
+  consecutive_frames_max_gap: 0
+  consecutive_frames_max_gap_seconds: 0.1
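
Because the `*_args` block is keyed by the provider class name, switching `data_loader_map_provider_class_type` activates a different block; a hedged sketch using the hypothetical provider from the earlier example:

```yaml
data_loader_map_provider_class_type: SimpleLoaderMapProvider
data_loader_map_provider_SimpleLoaderMapProvider_args:
  batch_size: 4
```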