SimpleDataLoaderMapProvider

Summary: Simple DataLoaderMapProvider instance

Reviewed By: davnov134

Differential Revision: D38326719

fbshipit-source-id: 58556833e76fae5790d25a59bea0aac4ce046bf1
This commit is contained in:
Jeremy Reizenstein
2022-08-02 12:10:05 -07:00
committed by Facebook GitHub Bot
parent c63ec81750
commit 3a063f5976
4 changed files with 121 additions and 1 deletions

View File

@@ -105,3 +105,9 @@ data_loader_map_provider_SequenceDataLoaderMapProvider_args:
sample_consecutive_frames: false
consecutive_frames_max_gap: 0
consecutive_frames_max_gap_seconds: 0.1
data_loader_map_provider_SimpleDataLoaderMapProvider_args:
batch_size: 1
num_workers: 0
dataset_length_train: 0
dataset_length_val: 0
dataset_length_test: 0

View File

@@ -10,6 +10,10 @@ import unittest.mock
import torch
from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_loader_map_provider import (
SequenceDataLoaderMapProvider,
SimpleDataLoaderMapProvider,
)
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.json_index_dataset import JsonIndexDataset
from pytorch3d.implicitron.tools.config import get_default_args
@@ -64,7 +68,6 @@ class TestDataSource(unittest.TestCase):
return
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "SequenceDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
@@ -73,8 +76,35 @@ class TestDataSource(unittest.TestCase):
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, SequenceDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 81)
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["test_known"])
break
def test_simple(self):
if os.environ.get("INSIDE_RE_WORKER") is not None:
return
args = get_default_args(ImplicitronDataSource)
args.dataset_map_provider_class_type = "JsonIndexDatasetMapProvider"
args.data_loader_map_provider_class_type = "SimpleDataLoaderMapProvider"
dataset_args = args.dataset_map_provider_JsonIndexDatasetMapProvider_args
dataset_args.category = "skateboard"
dataset_args.test_restrict_sequence_id = 0
dataset_args.n_frames_per_sequence = -1
dataset_args.dataset_root = "manifold://co3d/tree/extracted"
data_source = ImplicitronDataSource(**args)
self.assertIsInstance(
data_source.data_loader_map_provider, SimpleDataLoaderMapProvider
)
_, data_loaders = data_source.get_datasets_and_dataloaders()
self.assertEqual(len(data_loaders.train), 81)
for i in data_loaders.train:
self.assertEqual(i.frame_type, ["test_known"])
break