SimpleDataLoaderMapProvider

Summary: Add SimpleDataLoaderMapProvider, a simple implementation of DataLoaderMapProviderBase.

Reviewed By: davnov134

Differential Revision: D38326719

fbshipit-source-id: 58556833e76fae5790d25a59bea0aac4ce046bf1
This commit is contained in:
Jeremy Reizenstein 2022-08-02 12:10:05 -07:00 committed by Facebook GitHub Bot
parent c63ec81750
commit 3a063f5976
4 changed files with 121 additions and 1 deletions

View File

@ -117,6 +117,12 @@ data_source_ImplicitronDataSource_args:
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
dataset_length_train: 0
dataset_length_val: 0
dataset_length_test: 0
model_factory_ImplicitronModelFactory_args:
force_load: false
model_class_type: GenericModel

View File

@ -61,6 +61,84 @@ class DataLoaderMapProviderBase(ReplaceableBase):
raise NotImplementedError()
@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    """
    Trivial implementation of DataLoaderMapProviderBase.

    If a dataset returns batches from get_eval_batches(), then
    they will be what the corresponding dataloader returns,
    independently of any of the fields on this class.

    Otherwise, returns shuffled batches.

    Members:
        batch_size: The number of items in each shuffled batch.
        num_workers: Passed as num_workers to every DataLoader created.
        dataset_length_train: If positive, the number of shuffled batches
            in an epoch of the train data loader. Otherwise an epoch is a
            single shuffled pass over the dataset.
        dataset_length_val: As dataset_length_train, for the val data loader.
        dataset_length_test: As dataset_length_train, for the test data loader.
    """

    batch_size: int = 1
    num_workers: int = 0
    dataset_length_train: int = 0
    dataset_length_val: int = 0
    dataset_length_test: int = 0

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.

        Args:
            datasets: the train, val and test datasets; any of them
                may be None.

        Returns:
            A DataLoaderMap whose train, val and test entries are the data
            loaders for the corresponding datasets, or None where the
            dataset is None.
        """
        return DataLoaderMap(
            train=self._make_data_loader(
                datasets.train,
                self.dataset_length_train,
            ),
            val=self._make_data_loader(
                datasets.val,
                self.dataset_length_val,
            ),
            test=self._make_data_loader(
                datasets.test,
                self.dataset_length_test,
            ),
        )

    def _make_data_loader(
        self,
        dataset: Optional[DatasetBase],
        num_batches: int,
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns the dataloader for a dataset.

        Args:
            dataset: the dataset
            num_batches: possible ceiling on number of batches per epoch

        Returns:
            None if dataset is None, otherwise a DataLoader over the dataset.
        """
        if dataset is None:
            return None

        data_loader_kwargs = {
            "num_workers": self.num_workers,
            "collate_fn": dataset.frame_data_type.collate,
        }

        # Batches prescribed by the dataset take precedence over
        # every field of this class apart from num_workers.
        eval_batches = dataset.get_eval_batches()
        if eval_batches is not None:
            return DataLoader(
                dataset,
                batch_sampler=eval_batches,
                **data_loader_kwargs,
            )

        if num_batches > 0:
            num_samples = self.batch_size * num_batches
        else:
            num_samples = None

        # Sample with replacement only if a custom epoch length is requested
        # (it may exceed the dataset size). Otherwise an epoch is a true
        # shuffle: a permutation in which no item is repeated.
        sampler = RandomSampler(
            dataset,
            replacement=num_samples is not None,
            num_samples=num_samples,
        )
        # drop_last=True: an incomplete final batch is discarded.
        batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
        return DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            **data_loader_kwargs,
        )
class DoublePoolBatchSampler(Sampler[List[int]]):
"""
Batch sampler for making random batches of a single frame

View File

@ -105,3 +105,9 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
dataset_length_train: 0
dataset_length_val: 0
dataset_length_test: 0

View File

@ -10,6 +10,10 @@ import unittest.mock
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
SequenceDataLoaderMapProvider,
SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args
@ -64,7 +68,6 @@ class TestDataSource(unittest.TestCase):
return
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
@ -73,8 +76,35 @@ class TestDataSource(unittest.TestCase):
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 81)
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["test_known"])
break
def test_simple(self):
    """
    Build an ImplicitronDataSource configured to use
    SimpleDataLoaderMapProvider and check its train data loader.
    """
    # Nothing is checked when INSIDE_RE_WORKER is set in the environment.
    inside_re_worker = os.environ.get("INSIDE_RE_WORKER")
    if inside_re_worker is not None:
        return

    config = get_default_args(ImplicitronDataSource)
    config.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
    config.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"

    provider_config = config.dataset_map_provider_JsonIndexDatasetMapProvider_args
    provider_config.category = "skateboard"
    provider_config.test_restrict_sequence_id = 0
    provider_config.n_frames_per_sequence = -1
    provider_config.dataset_root = "manifold://co3d/tree/extracted"

    source = ImplicitronDataSource(**config)
    self.assertIsInstance(
        source.data_loader_map_provider,
        SimpleDataLoaderMapProvider,
    )

    _, loaders = source.get_datasets_and_dataloaders()
    train_loader = loaders.train
    self.assertEqual(len(train_loader), 81)

    # Only the first batch is inspected.
    for batch in train_loader:
        self.assertEqual(batch.frame_type, ["test_known"])
        break