mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 20:02:49 +08:00
SimpleDataLoaderMapProvider
Summary: Simple DataLoaderMapProvider instance Reviewed By: davnov134 Differential Revision: D38326719 fbshipit-source-id: 58556833e76fae5790d25a59bea0aac4ce046bf1
This commit is contained in:
parent
c63ec81750
commit
3a063f5976
@ -117,6 +117,12 @@ data_source_ImplicitronDataSource_args:
|
|||||||
sample_consecutive_frames: false
|
sample_consecutive_frames: false
|
||||||
consecutive_frames_max_gap: 0
|
consecutive_frames_max_gap: 0
|
||||||
consecutive_frames_max_gap_seconds: 0.1
|
consecutive_frames_max_gap_seconds: 0.1
|
||||||
|
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 0
|
||||||
|
dataset_length_train: 0
|
||||||
|
dataset_length_val: 0
|
||||||
|
dataset_length_test: 0
|
||||||
model_factory_ImplicitronModelFactory_args:
|
model_factory_ImplicitronModelFactory_args:
|
||||||
force_load: false
|
force_load: false
|
||||||
model_class_type: GenericModel
|
model_class_type: GenericModel
|
||||||
|
@ -61,6 +61,84 @@ class DataLoaderMapProviderBase(ReplaceableBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class SimpleDataLoaderMapProvider(DataLoaderMapProviderBase):
|
||||||
|
"""
|
||||||
|
Trivial implementation of DataLoaderMapProviderBase.
|
||||||
|
|
||||||
|
If a dataset returns batches from get_eval_batches(), then
|
||||||
|
they will be what the corresponding dataloader returns,
|
||||||
|
independently of any of the fields on this class.
|
||||||
|
|
||||||
|
Otherwise, returns shuffled batches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
batch_size: int = 1
|
||||||
|
num_workers: int = 0
|
||||||
|
dataset_length_train: int = 0
|
||||||
|
dataset_length_val: int = 0
|
||||||
|
dataset_length_test: int = 0
|
||||||
|
|
||||||
|
def get_data_loader_map(self, datasets: DatasetMap) -> DataLoaderMap:
|
||||||
|
"""
|
||||||
|
Returns a collection of data loaders for a given collection of datasets.
|
||||||
|
"""
|
||||||
|
return DataLoaderMap(
|
||||||
|
train=self._make_data_loader(
|
||||||
|
datasets.train,
|
||||||
|
self.dataset_length_train,
|
||||||
|
),
|
||||||
|
val=self._make_data_loader(
|
||||||
|
datasets.val,
|
||||||
|
self.dataset_length_val,
|
||||||
|
),
|
||||||
|
test=self._make_data_loader(
|
||||||
|
datasets.test,
|
||||||
|
self.dataset_length_test,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _make_data_loader(
|
||||||
|
self,
|
||||||
|
dataset: Optional[DatasetBase],
|
||||||
|
num_batches: int,
|
||||||
|
) -> Optional[DataLoader[FrameData]]:
|
||||||
|
"""
|
||||||
|
Returns the dataloader for a dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dataset: the dataset
|
||||||
|
num_batches: possible ceiling on number of batches per epoch
|
||||||
|
"""
|
||||||
|
if dataset is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data_loader_kwargs = {
|
||||||
|
"num_workers": self.num_workers,
|
||||||
|
"collate_fn": dataset.frame_data_type.collate,
|
||||||
|
}
|
||||||
|
|
||||||
|
eval_batches = dataset.get_eval_batches()
|
||||||
|
if eval_batches is not None:
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=eval_batches,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if num_batches > 0:
|
||||||
|
num_samples = self.batch_size * num_batches
|
||||||
|
else:
|
||||||
|
num_samples = None
|
||||||
|
sampler = RandomSampler(dataset, replacement=True, num_samples=num_samples)
|
||||||
|
batch_sampler = BatchSampler(sampler, self.batch_size, drop_last=True)
|
||||||
|
return DataLoader(
|
||||||
|
dataset,
|
||||||
|
batch_sampler=batch_sampler,
|
||||||
|
**data_loader_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class DoublePoolBatchSampler(Sampler[List[int]]):
|
class DoublePoolBatchSampler(Sampler[List[int]]):
|
||||||
"""
|
"""
|
||||||
Batch sampler for making random batches of a single frame
|
Batch sampler for making random batches of a single frame
|
||||||
|
@ -105,3 +105,9 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
|
|||||||
sample_consecutive_frames: false
|
sample_consecutive_frames: false
|
||||||
consecutive_frames_max_gap: 0
|
consecutive_frames_max_gap: 0
|
||||||
consecutive_frames_max_gap_seconds: 0.1
|
consecutive_frames_max_gap_seconds: 0.1
|
||||||
|
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 0
|
||||||
|
dataset_length_train: 0
|
||||||
|
dataset_length_val: 0
|
||||||
|
dataset_length_test: 0
|
||||||
|
@ -10,6 +10,10 @@ import unittest.mock
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
|
||||||
|
SequenceDataLoaderMapProvider,
|
||||||
|
SimpleDataLoaderMapProvider,
|
||||||
|
)
|
||||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||||
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
|
||||||
from pytorch3d.implicitron.tools.config import get_default_args
|
from pytorch3d.implicitron.tools.config import get_default_args
|
||||||
@ -64,7 +68,6 @@ class TestDataSource(unittest.TestCase):
|
|||||||
return
|
return
|
||||||
args = get_default_args(ImplicitronDataSource)
|
args = get_default_args(ImplicitronDataSource)
|
||||||
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||||
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
|
|
||||||
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
dataset_args.category = "skateboard"
|
dataset_args.category = "skateboard"
|
||||||
dataset_args.test_restrict_sequence_id = 0
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
@ -73,8 +76,35 @@ class TestDataSource(unittest.TestCase):
|
|||||||
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
|
|
||||||
data_source = ImplicitronDataSource(**args)
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
self.assertIsInstance(
|
||||||
|
data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
|
||||||
|
)
|
||||||
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
self.assertEqual(len(data_loaders.train), 81)
|
self.assertEqual(len(data_loaders.train), 81)
|
||||||
for i in data_loaders.train:
|
for i in data_loaders.train:
|
||||||
self.assertEqual(i.frame_type, ["test_known"])
|
self.assertEqual(i.frame_type, ["test_known"])
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def test_simple(self):
|
||||||
|
if os.environ.get("INSIDE_RE_WORKER") is not None:
|
||||||
|
return
|
||||||
|
args = get_default_args(ImplicitronDataSource)
|
||||||
|
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
|
||||||
|
args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"
|
||||||
|
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
|
||||||
|
dataset_args.category = "skateboard"
|
||||||
|
dataset_args.test_restrict_sequence_id = 0
|
||||||
|
dataset_args.n_frames_per_sequence = -1
|
||||||
|
|
||||||
|
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
|
||||||
|
|
||||||
|
data_source = ImplicitronDataSource(**args)
|
||||||
|
self.assertIsInstance(
|
||||||
|
data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
|
||||||
|
)
|
||||||
|
_, data_loaders = data_source.get_datasets_and_dataloaders()
|
||||||
|
|
||||||
|
self.assertEqual(len(data_loaders.train), 81)
|
||||||
|
for i in data_loaders.train:
|
||||||
|
self.assertEqual(i.frame_type, ["test_known"])
|
||||||
|
break
|
||||||
|
Loading…
x
Reference in New Issue
Block a user