From d3b7f5f421e1f702b1c8946f065edbf8fa30297d Mon Sep 17 00:00:00 2001 From: Jeremy Reizenstein Date: Tue, 12 Jul 2022 07:20:21 -0700 Subject: [PATCH] fix trainer test Summary: After recent accelerate change D37543870 (https://github.com/facebookresearch/pytorch3d/commit/aa8b03f31dc2a178f8d7da457df28f19b5917009), update interactive trainer test. Reviewed By: shapovalov Differential Revision: D37785932 fbshipit-source-id: 9211374323b6cfd80f6c5ff3a4fc1c0ca04b54ba --- .../implicitron_trainer/tests/test_experiment.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py index 31368cc4..7d3da06f 100644 --- a/projects/implicitron_trainer/tests/test_experiment.py +++ b/projects/implicitron_trainer/tests/test_experiment.py @@ -8,7 +8,6 @@ import os import unittest from pathlib import Path -import torch from hydra import compose, initialize_config_dir from omegaconf import OmegaConf @@ -24,6 +23,9 @@ def interactive_testing_requested() -> bool: return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1" +internal = os.environ.get("FB_TEST", False) + + DATA_DIR = Path(__file__).resolve().parent IMPLICITRON_CONFIGS_DIR = Path(__file__).resolve().parent.parent / "configs" DEBUG: bool = False @@ -40,7 +42,7 @@ class TestExperiment(unittest.TestCase): def test_from_defaults(self): # Test making minimal changes to the dataclass defaults. - if not interactive_testing_requested(): + if not interactive_testing_requested() or not internal: return cfg = OmegaConf.structured(experiment.ExperimentConfig) cfg.data_source_args.dataset_map_provider_class_type = ( @@ -56,11 +58,13 @@ class TestExperiment(unittest.TestCase): dataset_args.test_restrict_sequence_id = 0 dataset_args.dataset_root = "manifold://co3d/tree/extracted" dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5 + dataset_args.dataset_JsonIndexDataset_args.image_height = 80 + dataset_args.dataset_JsonIndexDataset_args.image_width = 80 dataloader_args.dataset_length_train = 1 + dataloader_args.dataset_length_val = 1 cfg.solver_args.max_epochs = 2 - device = torch.device("cuda:0") - experiment.run_training(cfg, device) + experiment.run_training(cfg) def test_yaml_contents(self): cfg = OmegaConf.structured(experiment.ExperimentConfig)