Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-08-02 03:42:50 +08:00.

Commit 3a063f5976 (parent c63ec81750): SimpleDataLoaderMapProvider.
Summary: Simple DataLoaderMapProvider instance. Reviewed By: davnov134. Differential Revision: D38326719. fbshipit-source-id: 58556833e76fae5790d25a59bea0aac4ce046bf1
@ -117,6 +117,12 @@ data_source_ImplicitronDataSource_args:
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
model_factory_ImplicitronModelFactory_args:
|
||||
force_load: false
|
||||
model_class_type: GenericModel
|
||||
|
@ -61,6 +61,84 @@ class DataLoaderMapProviderBase(ReplaceableBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@registry.register
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
    """
    Trivial implementation of DataLoaderMapProviderBase.

    If a dataset provides batches via get_eval_batches(), those batches
    are exactly what the corresponding dataloader yields, regardless of
    the configuration fields on this class. Otherwise, batches are drawn
    at random with replacement.

    Members:
        batch_size: number of frames per batch (unused when the dataset
            supplies its own eval batches).
        num_workers: number of worker processes for each DataLoader.
        dataset_length_train: if positive, number of batches per train epoch;
            otherwise the sampler is unbounded per epoch.
        dataset_length_val: as above, for the val split.
        dataset_length_test: as above, for the test split.
    """

    batch_size: int = 1
    num_workers: int = 0
    dataset_length_train: int = 0
    dataset_length_val: int = 0
    dataset_length_test: int = 0

    def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
        """
        Returns a collection of data loaders for a given collection of datasets.
        """
        splits = (
            (datasets.train, self.dataset_length_train),
            (datasets.val, self.dataset_length_val),
            (datasets.test, self.dataset_length_test),
        )
        train_loader, val_loader, test_loader = (
            self._make_data_loader(dataset, length) for dataset, length in splits
        )
        return DataLoaderMap(train=train_loader, val=val_loader, test=test_loader)

    def _make_data_loader(
        self,
        dataset: Optional[DatasetBase],
        num_batches: int,
    ) -> Optional[DataLoader[FrameData]]:
        """
        Returns the dataloader for a dataset.

        Args:
            dataset: the dataset, or None (in which case None is returned).
            num_batches: possible ceiling on number of batches per epoch.
        """
        if dataset is None:
            return None

        shared_kwargs = {
            "num_workers": self.num_workers,
            "collate_fn": dataset.frame_data_type.collate,
        }

        fixed_batches = dataset.get_eval_batches()
        if fixed_batches is not None:
            # The dataset dictates its own batches; our fields are ignored.
            return DataLoader(dataset, batch_sampler=fixed_batches, **shared_kwargs)

        # Random sampling with replacement, optionally capped per epoch.
        cap = self.batch_size * num_batches if num_batches > 0 else None
        random_sampler = RandomSampler(dataset, replacement=True, num_samples=cap)
        batches = BatchSampler(random_sampler, self.batch_size, drop_last=True)
        return DataLoader(dataset, batch_sampler=batches, **shared_kwargs)
|
||||
|
||||
|
||||
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||
"""
|
||||
Batch sampler for making random batches of a single frame
|
||||
|
@ -105,3 +105,9 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
||||
sample_consecutive_frames: false
|
||||
consecutive_frames_max_gap: 0
|
||||
consecutive_frames_max_gap_seconds: 0.1
|
||||
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
||||
batch_size: 1
|
||||
num_workers: 0
|
||||
dataset_length_train: 0
|
||||
dataset_length_val: 0
|
||||
dataset_length_test: 0
|
||||
|
@ -10,6 +10,10 @@ import unittest.mock
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||
SequenceDataLoaderMapProvider,
|
||||
SimpleDataLoaderMapProvider,
|
||||
)
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||
from pytorch3d.implicitron.tools.config import get_default_args
|
||||
@ -64,7 +68,6 @@ class TestDataSource(unittest.TestCase):
|
||||
return
|
||||
args = get_default_args(ImplicitronDataSource)
|
||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
|
||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||
dataset_args.category = "skateboard"
|
||||
dataset_args.test_restrict_sequence_id = 0
|
||||
@ -73,8 +76,35 @@ class TestDataSource(unittest.TestCase):
|
||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||
|
||||
data_source = ImplicitronDataSource(**args)
|
||||
self.assertIsInstance(
|
||||
data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
|
||||
)
|
||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||
self.assertEqual(len(data_loaders.train), 81)
|
||||
for i in data_loaders.train:
|
||||
self.assertEqual(i.frame_type, ["test_known"])
|
||||
break
|
||||
|
||||
def test_simple(self):
    """
    Smoke test for SimpleDataLoaderMapProvider: build an
    ImplicitronDataSource configured to use it and check the train
    loader's length and the frame type of its first batch.
    """
    # Skip when running inside the remote-execution worker environment.
    if os.environ.get("INSIDE_RE_WORKER") is not None:
        return

    args = get_default_args(ImplicitronDataSource)
    args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
    args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"

    provider_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
    provider_args.category = "skateboard"
    provider_args.test_restrict_sequence_id = 0
    provider_args.n_frames_per_sequence = -1
    provider_args.dataset_root = "manifold://co3d/tree/extracted"

    data_source = ImplicitronDataSource(**args)
    self.assertIsInstance(
        data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
    )
    _, data_loaders = data_source.get_datasets_and_dataloaders()

    self.assertEqual(len(data_loaders.train), 81)
    # Inspect only the first batch.
    first_batch = next(iter(data_loaders.train))
    self.assertEqual(first_batch.frame_type, ["test_known"])
|
||||
|
Loading…
x
Reference in New Issue
Block a user