mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
foreach optimizers
Summary: Allow using the new `foreach` option on optimizers. Reviewed By: shapovalov Differential Revision: D39694843 fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
This commit is contained in:
parent
db3c12abfb
commit
209c160a20
@ -4,6 +4,7 @@
|
|||||||
# This source code is licensed under the BSD-style license found in the
|
# This source code is licensed under the BSD-style license found in the
|
||||||
# LICENSE file in the root directory of this source tree.
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
increasing epoch indices at which the learning rate is modified.
|
increasing epoch indices at which the learning rate is modified.
|
||||||
momentum: Momentum factor for SGD optimizer.
|
momentum: Momentum factor for SGD optimizer.
|
||||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
|
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
|
||||||
|
foreach: Whether to use new "foreach" implementation of optimizer where
|
||||||
|
available (e.g. requires PyTorch 1.12.0 for Adam)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
betas: Tuple[float, ...] = (0.9, 0.999)
|
betas: Tuple[float, ...] = (0.9, 0.999)
|
||||||
@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
weight_decay: float = 0.0
|
weight_decay: float = 0.0
|
||||||
linear_exponential_lr_milestone: int = 200
|
linear_exponential_lr_milestone: int = 200
|
||||||
linear_exponential_start_gamma: float = 0.1
|
linear_exponential_start_gamma: float = 0.1
|
||||||
|
foreach: Optional[bool] = True
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
run_auto_creation(self)
|
run_auto_creation(self)
|
||||||
@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
|
|||||||
p_groups = [{"params": allprm, "lr": self.lr}]
|
p_groups = [{"params": allprm, "lr": self.lr}]
|
||||||
|
|
||||||
# Intialize the optimizer
|
# Intialize the optimizer
|
||||||
|
optimizer_kwargs: Dict[str, Any] = {
|
||||||
|
"lr": self.lr,
|
||||||
|
"weight_decay": self.weight_decay,
|
||||||
|
}
|
||||||
if self.breed == "SGD":
|
if self.breed == "SGD":
|
||||||
optimizer = torch.optim.SGD(
|
optimizer_class = torch.optim.SGD
|
||||||
p_groups,
|
optimizer_kwargs["momentum"] = self.momentum
|
||||||
lr=self.lr,
|
|
||||||
momentum=self.momentum,
|
|
||||||
weight_decay=self.weight_decay,
|
|
||||||
)
|
|
||||||
elif self.breed == "Adagrad":
|
elif self.breed == "Adagrad":
|
||||||
optimizer = torch.optim.Adagrad(
|
optimizer_class = torch.optim.Adagrad
|
||||||
p_groups, lr=self.lr, weight_decay=self.weight_decay
|
|
||||||
)
|
|
||||||
elif self.breed == "Adam":
|
elif self.breed == "Adam":
|
||||||
optimizer = torch.optim.Adam(
|
optimizer_class = torch.optim.Adam
|
||||||
p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
|
optimizer_kwargs["betas"] = self.betas
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"No such solver type {self.breed}")
|
raise ValueError(f"No such solver type {self.breed}")
|
||||||
|
|
||||||
|
if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
|
||||||
|
optimizer_kwargs["foreach"] = self.foreach
|
||||||
|
optimizer = optimizer_class(p_groups, **optimizer_kwargs)
|
||||||
logger.info(f"Solver type = {self.breed}")
|
logger.info(f"Solver type = {self.breed}")
|
||||||
|
|
||||||
# Load state from checkpoint
|
# Load state from checkpoint
|
||||||
|
@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
|
|||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
linear_exponential_lr_milestone: 200
|
linear_exponential_lr_milestone: 200
|
||||||
linear_exponential_start_gamma: 0.1
|
linear_exponential_start_gamma: 0.1
|
||||||
|
foreach: true
|
||||||
training_loop_ImplicitronTrainingLoop_args:
|
training_loop_ImplicitronTrainingLoop_args:
|
||||||
evaluator_class_type: ImplicitronEvaluator
|
evaluator_class_type: ImplicitronEvaluator
|
||||||
evaluator_ImplicitronEvaluator_args:
|
evaluator_ImplicitronEvaluator_args:
|
||||||
|
@ -9,13 +9,17 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from hydra import compose, initialize_config_dir
|
from hydra import compose, initialize_config_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
from projects.implicitron_trainer.impl.optimizer_factory import (
|
||||||
|
ImplicitronOptimizerFactory,
|
||||||
|
)
|
||||||
|
|
||||||
from .. import experiment
|
from .. import experiment
|
||||||
from .utils import interactive_testing_requested, intercept_logs
|
from .utils import interactive_testing_requested, intercept_logs
|
||||||
|
|
||||||
|
|
||||||
internal = os.environ.get("FB_TEST", False)
|
internal = os.environ.get("FB_TEST", False)
|
||||||
|
|
||||||
|
|
||||||
@ -151,6 +155,16 @@ class TestExperiment(unittest.TestCase):
|
|||||||
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
|
||||||
compose(file.name)
|
compose(file.name)
|
||||||
|
|
||||||
|
def test_optimizer_factory(self):
|
||||||
|
model = torch.nn.Linear(2, 2)
|
||||||
|
|
||||||
|
adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
|
||||||
|
self.assertIsInstance(adam, torch.optim.Adam)
|
||||||
|
sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
|
||||||
|
self.assertIsInstance(sgd, torch.optim.SGD)
|
||||||
|
adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
|
||||||
|
self.assertIsInstance(adagrad, torch.optim.Adagrad)
|
||||||
|
|
||||||
|
|
||||||
class TestNerfRepro(unittest.TestCase):
|
class TestNerfRepro(unittest.TestCase):
|
||||||
@unittest.skip("This test runs full blender training.")
|
@unittest.skip("This test runs full blender training.")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user