From cd5db076d5c849494b86e9404f2a30ff5aa7f0e8 Mon Sep 17 00:00:00 2001 From: Roman Shapovalov Date: Wed, 14 Jun 2023 10:51:47 -0700 Subject: [PATCH] Adding SQL dataset classes to ImplicitronDataSource imports Summary: Making it easier for the clients to use these datasets. Reviewed By: bottler Differential Revision: D46727179 fbshipit-source-id: cf619aee4c4c0222a74b30ea590cf37f08f014cc --- projects/implicitron_trainer/tests/experiment.yaml | 13 +++++++++++++ .../implicitron_trainer/tests/test_experiment.py | 2 ++ pytorch3d/implicitron/dataset/data_source.py | 10 ++++++++++ tests/implicitron/data/data_source.yaml | 13 +++++++++++++ tests/implicitron/test_data_source.py | 3 +++ 5 files changed, 41 insertions(+) diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index f2df83e5..d217c867 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -129,6 +129,19 @@ data_source_ImplicitronDataSource_args: dataset_length_train: 0 dataset_length_val: 0 dataset_length_test: 0 + data_loader_map_provider_TrainEvalDataLoaderMapProvider_args: + batch_size: 1 + num_workers: 0 + dataset_length_train: 0 + dataset_length_val: 0 + dataset_length_test: 0 + train_conditioning_type: SAME + val_conditioning_type: SAME + test_conditioning_type: KNOWN + images_per_seq_options: [] + sample_consecutive_frames: false + consecutive_frames_max_gap: 0 + consecutive_frames_max_gap_seconds: 0.1 model_factory_ImplicitronModelFactory_args: resume: true model_class_type: GenericModel diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index 590102fa..486d2134 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -136,6 +136,8 @@ class TestExperiment(unittest.TestCase): ds_arg = cfg.data_source_ImplicitronDataSource_args ds_arg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = "" ds_arg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = "" + if "dataset_map_provider_SqlIndexDatasetMapProvider_args" in ds_arg: + del ds_arg.dataset_map_provider_SqlIndexDatasetMapProvider_args cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8097 yaml = OmegaConf.to_yaml(cfg, sort_keys=False) if DEBUG: diff --git a/pytorch3d/implicitron/dataset/data_source.py b/pytorch3d/implicitron/dataset/data_source.py index 7ea5fe6f..a7989ac9 100644 --- a/pytorch3d/implicitron/dataset/data_source.py +++ b/pytorch3d/implicitron/dataset/data_source.py @@ -72,6 +72,16 @@ class ImplicitronDataSource(DataSourceBase): # pyre-ignore[13] from .rendered_mesh_dataset_map_provider import ( # noqa: F401 RenderedMeshDatasetMapProvider, ) + from .train_eval_data_loader_provider import ( # noqa: F401 + TrainEvalDataLoaderMapProvider, + ) + + try: + from .sql_dataset_provider import ( # noqa: F401 # pyre-ignore + SqlIndexDatasetMapProvider, + ) + except ModuleNotFoundError: + pass # environment without SQL dataset finally: pass diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index 291c5504..a444309c 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -116,3 +116,16 @@ data_loader_map_provider_SimpleDataLoaderMapProvider_args: dataset_length_train: 0 dataset_length_val: 0 dataset_length_test: 0 +data_loader_map_provider_TrainEvalDataLoaderMapProvider_args: + batch_size: 1 + num_workers: 0 + dataset_length_train: 0 + dataset_length_val: 0 + dataset_length_test: 0 + train_conditioning_type: SAME + val_conditioning_type: SAME + test_conditioning_type: KNOWN + images_per_seq_options: [] + sample_consecutive_frames: false + consecutive_frames_max_gap: 0 + consecutive_frames_max_gap_seconds: 0.1 diff --git a/tests/implicitron/test_data_source.py b/tests/implicitron/test_data_source.py index f289fef2..67c79fc5 100644 --- a/tests/implicitron/test_data_source.py +++ b/tests/implicitron/test_data_source.py @@ -68,6 +68,9 @@ class TestDataSource(unittest.TestCase): # making the test invariant to env variables cfg.dataset_map_provider_JsonIndexDatasetMapProvider_args.dataset_root = "" cfg.dataset_map_provider_JsonIndexDatasetMapProviderV2_args.dataset_root = "" + # making the test invariant to the presence of SQL dataset + if "dataset_map_provider_SqlIndexDatasetMapProvider_args" in cfg: + del cfg.dataset_map_provider_SqlIndexDatasetMapProvider_args yaml = OmegaConf.to_yaml(cfg, sort_keys=False) if DEBUG: (DATA_DIR / "data_source.yaml").write_text(yaml)