mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	fix trainer test
Summary: After recent accelerate change D37543870 (aa8b03f31d), update interactive trainer test.
Reviewed By: shapovalov
Differential Revision: D37785932
fbshipit-source-id: 9211374323b6cfd80f6c5ff3a4fc1c0ca04b54ba
			
			
This commit is contained in:
		
							parent
							
								
									4ecc9ea89d
								
							
						
					
					
						commit
						d3b7f5f421
					
				@ -8,7 +8,6 @@ import os
 | 
			
		||||
import unittest
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
from hydra import compose, initialize_config_dir
 | 
			
		||||
from omegaconf import OmegaConf
 | 
			
		||||
 | 
			
		||||
@ -24,6 +23,9 @@ def interactive_testing_requested() -> bool:
 | 
			
		||||
    return os.environ.get("PYTORCH3D_INTERACTIVE_TESTING", "") == "1"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
internal = os.environ.get("FB_TEST", False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DATA_DIR = Path(__file__).resolve().parent
 | 
			
		||||
IMPLICITRON_CONFIGS_DIR = Path(__file__).resolve().parent.parent / "configs"
 | 
			
		||||
DEBUG: bool = False
 | 
			
		||||
@ -40,7 +42,7 @@ class TestExperiment(unittest.TestCase):
 | 
			
		||||
 | 
			
		||||
    def test_from_defaults(self):
 | 
			
		||||
        # Test making minimal changes to the dataclass defaults.
 | 
			
		||||
        if not interactive_testing_requested():
 | 
			
		||||
        if not interactive_testing_requested() or not internal:
 | 
			
		||||
            return
 | 
			
		||||
        cfg = OmegaConf.structured(experiment.ExperimentConfig)
 | 
			
		||||
        cfg.data_source_args.dataset_map_provider_class_type = (
 | 
			
		||||
@ -56,11 +58,13 @@ class TestExperiment(unittest.TestCase):
 | 
			
		||||
        dataset_args.test_restrict_sequence_id = 0
 | 
			
		||||
        dataset_args.dataset_root = "manifold://co3d/tree/extracted"
 | 
			
		||||
        dataset_args.dataset_JsonIndexDataset_args.limit_sequences_to = 5
 | 
			
		||||
        dataset_args.dataset_JsonIndexDataset_args.image_height = 80
 | 
			
		||||
        dataset_args.dataset_JsonIndexDataset_args.image_width = 80
 | 
			
		||||
        dataloader_args.dataset_length_train = 1
 | 
			
		||||
        dataloader_args.dataset_length_val = 1
 | 
			
		||||
        cfg.solver_args.max_epochs = 2
 | 
			
		||||
 | 
			
		||||
        device = torch.device("cuda:0")
 | 
			
		||||
        experiment.run_training(cfg, device)
 | 
			
		||||
        experiment.run_training(cfg)
 | 
			
		||||
 | 
			
		||||
    def test_yaml_contents(self):
 | 
			
		||||
        cfg = OmegaConf.structured(experiment.ExperimentConfig)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user