foreach optimizers

Summary: Allow using the new `foreach` option on optimizers.

Reviewed By: shapovalov

Differential Revision: D39694843

fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
Author: Jeremy Reizenstein, 2022-09-22 05:11:56 -07:00 (committed by Facebook GitHub Bot)
parent db3c12abfb
commit 209c160a20
3 changed files with 33 additions and 13 deletions
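
For context: `foreach` is an option that recent PyTorch releases accept on many `torch.optim` optimizers (Adam gained it in 1.12.0). It switches the update step to a multi-tensor ("foreach") implementation that batches per-parameter operations and is typically faster, especially on GPU. A minimal standalone sketch of the underlying PyTorch option this commit exposes (plain PyTorch, not Implicitron-specific):

    import torch

    # Minimal sketch, assuming PyTorch >= 1.12: foreach=True requests the
    # multi-tensor update path; foreach=None lets PyTorch choose a default.
    model = torch.nn.Linear(2, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

    loss = model(torch.randn(4, 2)).sum()
    loss.backward()
    optimizer.step()  # parameters are updated via the foreach implementation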


@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use new "foreach" implementation of optimizer where
+            available (e.g. requires PyTorch 1.12.0 for Adam)
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
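
The notable part of this change is the version-agnostic dispatch: rather than passing `foreach` unconditionally, the factory inspects the chosen optimizer class's `__init__` signature and only forwards the flag when the installed PyTorch actually accepts it, so older releases keep working. A standalone sketch of the same pattern (the `make_optimizer` helper below is illustrative, not part of the commit):

    import inspect

    import torch

    def make_optimizer(optimizer_class, params, foreach=True, **kwargs):
        # Only forward `foreach` if this optimizer's constructor accepts it,
        # which keeps the call compatible with PyTorch versions that predate it.
        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
            kwargs["foreach"] = foreach
        return optimizer_class(params, **kwargs)

    opt = make_optimizer(torch.optim.Adam, torch.nn.Linear(2, 2).parameters(), lr=1e-3)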


@@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   weight_decay: 0.0
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
+  foreach: true
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:
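
Because `foreach: true` is now part of the default experiment configuration shown above, it can be overridden like any other optimizer-factory field. A hedged sketch using OmegaConf (which the tests below already import); the file name and the exact nesting of the key are assumptions made for illustration:

    from omegaconf import OmegaConf

    # Hypothetical example: flip the new default off in a saved experiment config.
    cfg = OmegaConf.load("experiment.yaml")  # assumed file name
    cfg.optimizer_factory_ImplicitronOptimizerFactory_args.foreach = False
    OmegaConf.save(cfg, "experiment.yaml")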


@@ -9,13 +9,17 @@ import tempfile
 import unittest
 from pathlib import Path
 
+import torch
 from hydra import compose, initialize_config_dir
 from omegaconf import OmegaConf
+from projects.implicitron_trainer.impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+)
 
 from .. import experiment
 from .utils import interactive_testing_requested, intercept_logs
 
 internal = os.environ.get("FB_TEST", False)
@@ -151,6 +155,16 @@ class TestExperiment(unittest.TestCase):
             with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
                 compose(file.name)
 
+    def test_optimizer_factory(self):
+        model = torch.nn.Linear(2, 2)
+        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
+        self.assertIsInstance(adam, torch.optim.Adam)
+        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
+        self.assertIsInstance(sgd, torch.optim.SGD)
+        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
+        self.assertIsInstance(adagrad, torch.optim.Adagrad)
+
 
 class TestNerfRepro(unittest.TestCase):
     @unittest.skip("This test runs full blender training.")
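
The new test_optimizer_factory above only asserts that each breed still constructs the expected optimizer class. A hedged follow-on check (not part of the commit) that the flag actually reaches the optimizer could look like this; it assumes PyTorch >= 1.12 so that `torch.optim.Adam` accepts `foreach`:

    import torch
    from projects.implicitron_trainer.impl.optimizer_factory import (
        ImplicitronOptimizerFactory,
    )

    # Build an Adam optimizer with foreach disabled and confirm the setting
    # was forwarded into the optimizer's defaults.
    model = torch.nn.Linear(2, 2)
    adam, _ = ImplicitronOptimizerFactory(breed="Adam", foreach=False)(0, model)
    assert adam.defaults.get("foreach") is False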