diff --git a/projects/implicitron_trainer/impl/optimizer_factory.py b/projects/implicitron_trainer/impl/optimizer_factory.py
index 184adb92..9e4a5227 100644
--- a/projects/implicitron_trainer/impl/optimizer_factory.py
+++ b/projects/implicitron_trainer/impl/optimizer_factory.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use new "foreach" implementation of optimizer where
+            available (e.g. requires PyTorch 1.12.0 for Adam)
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml
index 9588043b..bd52beac 100644
--- a/projects/implicitron_trainer/tests/experiment.yaml
+++ b/projects/implicitron_trainer/tests/experiment.yaml
@@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   weight_decay: 0.0
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
+  foreach: true
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:
diff --git a/projects/implicitron_trainer/tests/test_experiment.py b/projects/implicitron_trainer/tests/test_experiment.py
index 1c261e7a..d16c788d 100644
--- a/projects/implicitron_trainer/tests/test_experiment.py
+++ b/projects/implicitron_trainer/tests/test_experiment.py
@@ -9,13 +9,17 @@ import tempfile
 import unittest
 from pathlib import Path
 
+import torch
+
 from hydra import compose, initialize_config_dir
 from omegaconf import OmegaConf
 
+from projects.implicitron_trainer.impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+)
 from .. import experiment
 from .utils import interactive_testing_requested, intercept_logs
 
-
 internal = os.environ.get("FB_TEST", False)
 
 
@@ -151,6 +155,16 @@ class TestExperiment(unittest.TestCase):
             with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
                 compose(file.name)
 
+    def test_optimizer_factory(self):
+        model = torch.nn.Linear(2, 2)
+
+        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
+        self.assertIsInstance(adam, torch.optim.Adam)
+        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
+        self.assertIsInstance(sgd, torch.optim.SGD)
+        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
+        self.assertIsInstance(adagrad, torch.optim.Adagrad)
+
 
 class TestNerfRepro(unittest.TestCase):
     @unittest.skip("This test runs full blender training.")
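For reference, the following is a minimal standalone sketch (not part of the patch) of the capability check this diff relies on: the "foreach" argument is forwarded to the optimizer constructor only when inspect.signature shows the constructor accepts it, so older PyTorch releases (e.g. Adam before 1.12.0) keep working. The helper name build_optimizer is hypothetical and only uses public torch.optim APIs.

# Hypothetical sketch; build_optimizer is an illustrative name, not part of the patch.
import inspect
from typing import Any, Dict, Optional

import torch


def build_optimizer(
    params,
    breed: str = "Adam",
    lr: float = 1e-3,
    foreach: Optional[bool] = True,
) -> torch.optim.Optimizer:
    optimizer_class = {
        "SGD": torch.optim.SGD,
        "Adagrad": torch.optim.Adagrad,
        "Adam": torch.optim.Adam,
    }[breed]
    kwargs: Dict[str, Any] = {"lr": lr}
    # Pass "foreach" only if this optimizer's constructor accepts it; older
    # PyTorch versions would otherwise raise TypeError on the unknown kwarg.
    if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
        kwargs["foreach"] = foreach
    return optimizer_class(params, **kwargs)


# Example usage:
# opt = build_optimizer(torch.nn.Linear(2, 2).parameters(), breed="SGD", lr=0.1)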