foreach optimizers

Summary: Allow using the new `foreach` option on optimizers.

Reviewed By: shapovalov

Differential Revision: D39694843

fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
Author: Jeremy Reizenstein
Committed by: Facebook GitHub Bot
Date: 2022-09-22 05:11:56 -07:00
Parent: db3c12abfb
Commit: 209c160a20
3 changed files with 33 additions and 13 deletions
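The core of the change is an `inspect.signature` gate: `foreach` is only forwarded to the optimizer constructor when the installed PyTorch version actually accepts it (e.g. `torch.optim.Adam` gained the argument in PyTorch 1.12.0). A minimal standalone sketch of that pattern, outside the Implicitron factory and using a hypothetical toy model:

```python
import inspect

import torch

# Hypothetical tiny model, just to have some parameters to optimize.
model = torch.nn.Linear(4, 2)

optimizer_class = torch.optim.Adam
optimizer_kwargs = {"lr": 1e-3, "weight_decay": 0.0}

# Only pass `foreach` if this PyTorch build's optimizer constructor accepts it;
# on older versions the keyword is simply absent from the signature.
if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
    optimizer_kwargs["foreach"] = True

optimizer = optimizer_class(model.parameters(), **optimizer_kwargs)
```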

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import inspect
 import logging
 import os
 from typing import Any, Dict, Optional, Tuple
@@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
+        foreach: Whether to use the new "foreach" implementation of the optimizer
+            where available (e.g. requires PyTorch 1.12.0 for Adam).
     """
 
     betas: Tuple[float, ...] = (0.9, 0.999)
@@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     weight_decay: float = 0.0
     linear_exponential_lr_milestone: int = 200
     linear_exponential_start_gamma: float = 0.1
+    foreach: Optional[bool] = True
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         p_groups = [{"params": allprm, "lr": self.lr}]
 
         # Intialize the optimizer
+        optimizer_kwargs: Dict[str, Any] = {
+            "lr": self.lr,
+            "weight_decay": self.weight_decay,
+        }
         if self.breed == "SGD":
-            optimizer = torch.optim.SGD(
-                p_groups,
-                lr=self.lr,
-                momentum=self.momentum,
-                weight_decay=self.weight_decay,
-            )
+            optimizer_class = torch.optim.SGD
+            optimizer_kwargs["momentum"] = self.momentum
         elif self.breed == "Adagrad":
-            optimizer = torch.optim.Adagrad(
-                p_groups, lr=self.lr, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adagrad
         elif self.breed == "Adam":
-            optimizer = torch.optim.Adam(
-                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
-            )
+            optimizer_class = torch.optim.Adam
+            optimizer_kwargs["betas"] = self.betas
         else:
             raise ValueError(f"No such solver type {self.breed}")
+
+        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
+            optimizer_kwargs["foreach"] = self.foreach
+        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
         logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
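As a rough illustration of what the flag controls (assuming PyTorch >= 1.12.0, where `torch.optim.Adam` accepts `foreach`): the foreach implementation batches the per-parameter updates into multi-tensor ops, and a single step should closely match the default per-parameter loop. A hedged sanity sketch, not part of this commit:

```python
import torch


def one_adam_step(foreach: bool) -> torch.Tensor:
    # Identical setup for both runs so the updates are comparable.
    torch.manual_seed(0)
    param = torch.nn.Parameter(torch.randn(16))
    param.grad = torch.ones_like(param)
    opt = torch.optim.Adam([param], lr=1e-2, foreach=foreach)
    opt.step()
    return param.detach().clone()


# The two implementations should agree (up to floating-point noise).
print(torch.allclose(one_adam_step(foreach=False), one_adam_step(foreach=True)))
```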