mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)

commit 0f12c51646 (parent 79c61a2d86)

data_loader_map_provider

Summary: Replace dataloader_zoo with a pluggable DataLoaderMapProvider.

Reviewed By: shapovalov

Differential Revision: D36475441

fbshipit-source-id: d16abb190d876940434329928f2e3f2794a25416
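The point of the change is that data loader construction becomes a replaceable component: any subclass of DataLoaderMapProviderBase registered with Implicitron's registry can be selected by name from the config. A minimal sketch of a custom provider, assuming only the API introduced in this commit (the SimpleDataLoaderMapProvider class and its plain batching are illustrative, not part of the commit):

```python
import torch

from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    DataLoaderMap,
    DataLoaderMapProviderBase,
)
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.tools.config import registry


@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    # Hypothetical provider: plain sequential batches instead of the
    # scene-aware sampling done by SequenceDataLoaderMapProvider.
    batch_size: int = 1

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        def loader(dataset):
            if dataset is None:
                return None
            return torch.utils.data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                collate_fn=FrameData.collate,  # same collation the default provider uses
            )

        return DataLoaderMap(
            train=loader(datasets.train),
            val=loader(datasets.val),
            test=loader(datasets.test),
        )
```

Once registered, such a class would be selected via data_loader_map_provider_class_type: SimpleDataLoaderMapProvider and configured through the matching data_loader_map_provider_SimpleDataLoaderMapProvider_args node, following the config tree below.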
@@ -237,7 +237,7 @@ generic_model_args: GenericModel
 solver_args: init_optimizer
 data_source_args: ImplicitronDataSource
     └-- dataset_map_provider_*_args
-    └-- dataloader_args
+    └-- data_loader_map_provider_*_args
 ```

 Please look at the annotations of the respective classes or functions for the lists of hyperparameters.
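Concretely, the wildcard expands to the name of the selected provider class. With the default SequenceDataLoaderMapProvider, a typical override (mirroring the config updates below) looks like:

```yaml
data_source_args:
  data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
    batch_size: 10
    dataset_len: 1000
    dataset_len_val: 1
```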
@@ -7,7 +7,7 @@ visualize_interval: 0
 visdom_port: 8097
 data_source_args:
   dataset_provider_class_type: JsonIndexDatasetMapProvider
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
@@ -2,7 +2,7 @@ defaults:
 - repro_base.yaml
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
@@ -2,7 +2,7 @@ defaults:
 - repro_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 1
     dataset_len: 1000
     dataset_len_val: 1
@@ -2,7 +2,7 @@ defaults:
 - repro_singleseq_base
 - _self_
 data_source_args:
-  dataloader_args:
+  data_loader_map_provider_SequenceDataLoaderMapProvider_args:
     batch_size: 10
     dataset_len: 1000
     dataset_len_val: 1
projects/implicitron_trainer/experiment.py
@@ -64,8 +64,8 @@ import tqdm
 from omegaconf import DictConfig, OmegaConf
 from packaging import version
 from pytorch3d.implicitron.dataset import utils as ds_utils
+from pytorch3d.implicitron.dataset.data_loader_map_provider import DataLoaderMap
 from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Task
-from pytorch3d.implicitron.dataset.dataloader_zoo import Dataloaders
 from pytorch3d.implicitron.dataset.dataset_base import FrameData
 from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
 from pytorch3d.implicitron.dataset.implicitron_dataset import ImplicitronDataset
@@ -553,7 +553,7 @@ def _eval_and_dump(
     cfg,
     task: Task,
     datasets: DatasetMap,
-    dataloaders: Dataloaders,
+    dataloaders: DataLoaderMap,
     model,
     stats,
     device,
@@ -566,7 +566,7 @@ def _eval_and_dump(
     dataloader = dataloaders.test

     if dataloader is None:
-        raise ValueError('Dataloaders have to contain the "test" entry for eval!')
+        raise ValueError('DataLoaderMap have to contain the "test" entry for eval!')

     if task == Task.SINGLE_SEQUENCE:
         if datasets.train is None:
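Unlike the old Dataloaders dataclass, the new DataLoaderMap also supports keyed lookup by split name with validation (see its __getitem__ in the new file below), so the attribute access used here has an equivalent dict-style form; a brief sketch:

```python
# Two equivalent ways to fetch the evaluation loader from a DataLoaderMap:
test_loader = dataloaders.test       # attribute access, as used in _eval_and_dump
test_loader = dataloaders["test"]    # keyed access; raises ValueError for unknown split names
```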
pytorch3d/implicitron/dataset/data_loader_map_provider.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from pytorch3d.implicitron.tools.config import registry, ReplaceableBase

from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler


@dataclass
class DataLoaderMap:
    """
    A collection of data loaders for Implicitron.

    Members:

        train: a data loader for training
        val: a data loader for validating during training
        test: a data loader for final evaluation
    """

    train: Optional[torch.utils.data.DataLoader[FrameData]]
    val: Optional[torch.utils.data.DataLoader[FrameData]]
    test: Optional[torch.utils.data.DataLoader[FrameData]]

    def __getitem__(
        self, split: str
    ) -> Optional[torch.utils.data.DataLoader[FrameData]]:
        """
        Get one of the data loaders by key (name of data split).
        """
        if split not in ["train", "val", "test"]:
            raise ValueError(f"{split} was not a valid split name (train/val/test)")
        return getattr(self, split)


class DataLoaderMapProviderBase(ReplaceableBase):
    """
    Provider of a collection of data loaders for a given collection of datasets.
    """

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.
        """
        raise NotImplementedError()


@registry.register
class SequenceDataLoaderMapProvider(DataLoaderMapProviderBase):
    """
    The default implementation of DataLoaderMapProviderBase.

    Members:
        batch_size: The size of the batch of the data loader.
        num_workers: Number of data-loading threads.
        dataset_len: The number of batches in a training epoch.
        dataset_len_val: The number of batches in a validation epoch.
        images_per_seq_options: Possible numbers of images sampled per sequence.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestamps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as a middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.
    """

    batch_size: int = 1
    num_workers: int = 0
    dataset_len: int = 1000
    dataset_len_val: int = 1
    images_per_seq_options: Sequence[int] = (2,)
    sample_consecutive_frames: bool = False
    consecutive_frames_max_gap: int = 0
    consecutive_frames_max_gap_seconds: float = 0.1

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.
        """

        data_loader_kwargs = {
            "num_workers": self.num_workers,
            "collate_fn": FrameData.collate,
        }

        def train_or_val_loader(
            dataset: Optional[ImplicitronDatasetBase], num_batches: int
        ) -> Optional[torch.utils.data.DataLoader]:
            if dataset is None:
                return None
            batch_sampler = SceneBatchSampler(
                dataset,
                self.batch_size,
                num_batches=len(dataset) if num_batches <= 0 else num_batches,
                images_per_seq_options=self.images_per_seq_options,
                sample_consecutive_frames=self.sample_consecutive_frames,
                consecutive_frames_max_gap=self.consecutive_frames_max_gap,
                consecutive_frames_max_gap_seconds=self.consecutive_frames_max_gap_seconds,
            )
            return torch.utils.data.DataLoader(
                dataset,
                batch_sampler=batch_sampler,
                **data_loader_kwargs,
            )

        train_data_loader = train_or_val_loader(datasets.train, self.dataset_len)
        val_data_loader = train_or_val_loader(datasets.val, self.dataset_len_val)

        test_dataset = datasets.test
        if test_dataset is not None:
            test_data_loader = torch.utils.data.DataLoader(
                test_dataset,
                batch_sampler=test_dataset.get_eval_batches(),
                **data_loader_kwargs,
            )
        else:
            test_data_loader = None

        return DataLoaderMap(
            train=train_data_loader, val=val_data_loader, test=test_data_loader
        )
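Taken on its own, the new provider can be driven directly. A sketch, assuming `datasets` is a DatasetMap already built by some dataset map provider, and that the class is first expanded by the config system's expand_args_fields helper (as is usual for ReplaceableBase subclasses):

```python
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    SequenceDataLoaderMapProvider,
)
from pytorch3d.implicitron.tools.config import expand_args_fields

# Make the configurable class constructible with keyword arguments.
expand_args_fields(SequenceDataLoaderMapProvider)

provider = SequenceDataLoaderMapProvider(
    batch_size=10,
    dataset_len=1000,
    dataset_len_val=1,
)
loaders = provider.get_data_loader_map(datasets)  # datasets: an assumed DatasetMap

for frame_data in loaders.train:
    # each batch is a FrameData object assembled by FrameData.collate
    break
```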
pytorch3d/implicitron/dataset/data_source.py
@@ -6,15 +6,10 @@

 from typing import Tuple

-from omegaconf import DictConfig
-from pytorch3d.implicitron.tools.config import (
-    get_default_args_field,
-    ReplaceableBase,
-    run_auto_creation,
-)
+from pytorch3d.implicitron.tools.config import ReplaceableBase, run_auto_creation

 from . import json_index_dataset_map_provider  # noqa
-from .dataloader_zoo import dataloader_zoo, Dataloaders
+from .data_loader_map_provider import DataLoaderMap, DataLoaderMapProviderBase
 from .dataset_map_provider import DatasetMap, DatasetMapProviderBase, Task


@@ -24,7 +19,7 @@ class DataSourceBase(ReplaceableBase):
     and DataLoader configuration.
     """

-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         raise NotImplementedError()


@@ -36,18 +31,20 @@ class ImplicitronDataSource(DataSourceBase):  # pyre-ignore[13]
     Members:
         dataset_map_provider_class_type: identifies type for dataset_map_provider.
             e.g. JsonIndexDatasetMapProvider for Co3D.
+        data_loader_map_provider_class_type: identifies type for data_loader_map_provider.
     """

     dataset_map_provider: DatasetMapProviderBase
     dataset_map_provider_class_type: str
-    dataloader_args: DictConfig = get_default_args_field(dataloader_zoo)
+    data_loader_map_provider: DataLoaderMapProviderBase
+    data_loader_map_provider_class_type: str = "SequenceDataLoaderMapProvider"

     def __post_init__(self):
         run_auto_creation(self)

-    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, Dataloaders]:
+    def get_datasets_and_dataloaders(self) -> Tuple[DatasetMap, DataLoaderMap]:
         datasets = self.dataset_map_provider.get_dataset_map()
-        dataloaders = dataloader_zoo(datasets, **self.dataloader_args)
+        dataloaders = self.data_loader_map_provider.get_data_loader_map(datasets)
         return datasets, dataloaders

     def get_task(self) -> Task:
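With run_auto_creation, setting data_loader_map_provider_class_type is enough for ImplicitronDataSource to instantiate the matching registered provider from the corresponding *_args node. A sketch of the resulting flow, assuming the config system's usual get_default_args/keyword-construction pattern, and eliding the dataset map provider's required fields (e.g. category):

```python
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.tools.config import get_default_args

args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
args.data_loader_map_provider_SequenceDataLoaderMapProvider_args.batch_size = 10
# ... the JsonIndexDatasetMapProvider args (category etc.) must also be filled in.

data_source = ImplicitronDataSource(**args)  # __post_init__ runs run_auto_creation
datasets, dataloaders = data_source.get_datasets_and_dataloaders()
```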
pytorch3d/implicitron/dataset/dataloader_zoo.py (deleted file, 116 lines)
@@ -1,116 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from typing import Optional, Sequence

import torch
from pytorch3d.implicitron.tools.config import enable_get_default_args

from .dataset_base import FrameData, ImplicitronDatasetBase
from .dataset_map_provider import DatasetMap
from .scene_batch_sampler import SceneBatchSampler


@dataclass
class Dataloaders:
    """
    A provider of dataloaders for implicitron.

    Members:

        train: a dataloader for training
        val: a dataloader for validating during training
        test: a dataloader for final evaluation
    """

    train: Optional[torch.utils.data.DataLoader[FrameData]]
    val: Optional[torch.utils.data.DataLoader[FrameData]]
    test: Optional[torch.utils.data.DataLoader[FrameData]]


def dataloader_zoo(
    datasets: DatasetMap,
    batch_size: int = 1,
    num_workers: int = 0,
    dataset_len: int = 1000,
    dataset_len_val: int = 1,
    images_per_seq_options: Sequence[int] = (2,),
    sample_consecutive_frames: bool = False,
    consecutive_frames_max_gap: int = 0,
    consecutive_frames_max_gap_seconds: float = 0.1,
) -> Dataloaders:
    """
    Returns a set of dataloaders for a given set of datasets.

    Args:
        datasets: A dictionary containing the
            `"dataset_subset_name": torch_dataset_object` key, value pairs.
        batch_size: The size of the batch of the dataloader.
        num_workers: Number data-loading threads.
        dataset_len: The number of batches in a training epoch.
        dataset_len_val: The number of batches in a validation epoch.
        images_per_seq_options: Possible numbers of images sampled per sequence.
        sample_consecutive_frames: if True, will sample a contiguous interval of frames
            in the sequence. It first sorts the frames by timestimps when available,
            otherwise by frame numbers, finds the connected segments within the sequence
            of sufficient length, then samples a random pivot element among them and
            ideally uses it as a middle of the temporal window, shifting the borders
            where necessary. This strategy mitigates the bias against shorter segments
            and their boundaries.
        consecutive_frames_max_gap: if a number > 0, then used to define the maximum
            difference in frame_number of neighbouring frames when forming connected
            segments; if both this and consecutive_frames_max_gap_seconds are 0s,
            the whole sequence is considered a segment regardless of frame numbers.
        consecutive_frames_max_gap_seconds: if a number > 0.0, then used to define the
            maximum difference in frame_timestamp of neighbouring frames when forming
            connected segments; if both this and consecutive_frames_max_gap are 0s,
            the whole sequence is considered a segment regardless of frame timestamps.

    Returns:
        dataloaders: A dictionary containing the
            `"dataset_subset_name": torch_dataloader_object` key, value pairs.
    """

    dataloader_kwargs = {"num_workers": num_workers, "collate_fn": FrameData.collate}

    def train_or_val_loader(
        dataset: Optional[ImplicitronDatasetBase], num_batches: int
    ) -> Optional[torch.utils.data.DataLoader]:
        if dataset is None:
            return None
        batch_sampler = SceneBatchSampler(
            dataset,
            batch_size,
            num_batches=len(dataset) if num_batches <= 0 else num_batches,
            images_per_seq_options=images_per_seq_options,
            sample_consecutive_frames=sample_consecutive_frames,
            consecutive_frames_max_gap=consecutive_frames_max_gap,
            consecutive_frames_max_gap_seconds=consecutive_frames_max_gap_seconds,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            **dataloader_kwargs,
        )

    train_dataloader = train_or_val_loader(datasets.train, dataset_len)
    val_dataloader = train_or_val_loader(datasets.val, dataset_len_val)

    test_dataset = datasets.test
    if test_dataset is not None:
        test_dataloader = torch.utils.data.DataLoader(
            test_dataset,
            batch_sampler=test_dataset.get_eval_batches(),
            **dataloader_kwargs,
        )
    else:
        test_dataloader = None

    return Dataloaders(train=train_dataloader, val=val_dataloader, test=test_dataloader)


enable_get_default_args(dataloader_zoo)
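For call sites, the migration is mechanical: the options that were keyword arguments of the deleted free function become fields of the replaceable provider class. A sketch, assuming an existing DatasetMap named `datasets`:

```python
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
    SequenceDataLoaderMapProvider,
)

# Before this commit:
#     dataloaders = dataloader_zoo(datasets, batch_size=10, dataset_len=1000)
# After this commit:
provider = SequenceDataLoaderMapProvider(batch_size=10, dataset_len=1000)
dataloaders = provider.get_data_loader_map(datasets)
```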
@@ -1,14 +1,5 @@
 dataset_map_provider_class_type: ???
-dataloader_args:
-  batch_size: 1
-  num_workers: 0
-  dataset_len: 1000
-  dataset_len_val: 1
-  images_per_seq_options:
-  - 2
-  sample_consecutive_frames: false
-  consecutive_frames_max_gap: 0
-  consecutive_frames_max_gap_seconds: 0.1
+data_loader_map_provider_class_type: SequenceDataLoaderMapProvider
 dataset_map_provider_JsonIndexDatasetMapProvider_args:
   category: ???
   task_str: singlesequence
@@ -31,3 +22,13 @@ dataset_map_provider_JsonIndexDatasetMapProvider_args:
   image_height: 800
   remove_empty_masks: true
   path_manager: null
+data_loader_map_provider_SequenceDataLoaderMapProvider_args:
+  batch_size: 1
+  num_workers: 0
+  dataset_len: 1000
+  dataset_len_val: 1
+  images_per_seq_options:
+  - 2
+  sample_consecutive_frames: false
+  consecutive_frames_max_gap: 0
+  consecutive_frames_max_gap_seconds: 0.1