foreach optimizers

Summary: Allow using the new `foreach` option on optimizers.

Reviewed By: shapovalov

Differential Revision: D39694843

fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
Author: Jeremy Reizenstein (2022-09-22 05:11:56 -07:00), committed by Facebook GitHub Bot
Parent: db3c12abfb
Commit: 209c160a20
3 changed files with 33 additions and 13 deletions
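
For context, the `foreach` option mentioned in the summary is PyTorch's multi-tensor optimizer implementation (available for Adam from PyTorch 1.12.0, per the docstring added below). A minimal sketch of the flag at the plain torch.optim level, independent of this commit:

import torch

# Sketch only (not part of this commit): the torch.optim `foreach` flag that
# the optimizer factory below forwards. With foreach=True, Adam applies its
# update with multi-tensor (_foreach_) ops across all parameters of a group
# instead of a per-parameter Python loop; foreach=False keeps the loop.
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=True)

loss = model(torch.randn(2, 4)).sum()
loss.backward()
optimizer.step()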


@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
            increasing epoch indices at which the learning rate is modified.
        momentum: Momentum factor for SGD optimizer.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use new "foreach" implementation of optimizer where
+            available (e.g. requires PyTorch 1.12.0 for Adam)
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
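
The field added above is Optional[bool] with a default of True. On a new enough PyTorch all three settings are meaningful; a small illustration (assumes PyTorch >= 1.12.0, not taken from this commit):

import torch

params = [torch.zeros(3, requires_grad=True)]
# True forces the multi-tensor update, False forces the per-parameter loop,
# and None lets torch.optim choose its own default.
for choice in (True, False, None):
    torch.optim.Adam(params, lr=0.1, foreach=choice)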
@@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
 
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint

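The refactor above gathers shared keyword arguments in optimizer_kwargs and only passes foreach when the selected optimizer's constructor accepts it, so older PyTorch releases keep working. The same probe in isolation, as a sketch that assumes only torch and the standard library:

import inspect

import torch

# Pass `foreach` only when the chosen optimizer's __init__ accepts it
# (older PyTorch releases do not), so the same code runs on old and new
# versions alike.
optimizer_class = torch.optim.Adam
optimizer_kwargs = {"lr": 0.1}
if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
    optimizer_kwargs["foreach"] = True

params = [torch.zeros(3, requires_grad=True)]
optimizer = optimizer_class(params, **optimizer_kwargs)
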

@@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   weight_decay: 0.0
   linear_exponential_lr_milestone: 200
   linear_exponential_start_gamma: 0.1
+  foreach: true
 training_loop_ImplicitronTrainingLoop_args:
   evaluator_class_type: ImplicitronEvaluator
   evaluator_ImplicitronEvaluator_args:

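The regenerated config above records the new default, foreach: true. To opt out for a particular run, the field can be overridden like any other factory argument; a hedged sketch that reuses the import path and the (epoch, model) call signature from the test added below:

import torch

from projects.implicitron_trainer.impl.optimizer_factory import (
    ImplicitronOptimizerFactory,
)

# Hedged sketch: `breed` and `foreach` are fields shown in the diff above;
# the (epoch, model) call signature mirrors the new test below.
model = torch.nn.Linear(2, 2)
optimizer, scheduler = ImplicitronOptimizerFactory(breed="Adam", foreach=False)(
    0, model
)
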

@@ -9,13 +9,17 @@ import tempfile
 import unittest
 from pathlib import Path
 
+import torch
 from hydra import compose, initialize_config_dir
 from omegaconf import OmegaConf
+from projects.implicitron_trainer.impl.optimizer_factory import (
+    ImplicitronOptimizerFactory,
+)
 
 from .. import experiment
 from .utils import interactive_testing_requested, intercept_logs
 
 internal = os.environ.get("FB_TEST", False)
@@ -151,6 +155,16 @@ class TestExperiment(unittest.TestCase):
             with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
                 compose(file.name)
 
+    def test_optimizer_factory(self):
+        model = torch.nn.Linear(2, 2)
+        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
+        self.assertIsInstance(adam, torch.optim.Adam)
+        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
+        self.assertIsInstance(sgd, torch.optim.SGD)
+        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
+        self.assertIsInstance(adagrad, torch.optim.Adagrad)
+
 
 class TestNerfRepro(unittest.TestCase):
     @unittest.skip("This test runs full blender training.")