mirror of
				https://github.com/facebookresearch/pytorch3d.git
				synced 2025-11-04 18:02:14 +08:00 
			
		
		
		
	foreach optimizers
Summary: Allow using the new `foreach` option on optimizers. Reviewed By: shapovalov Differential Revision: D39694843 fbshipit-source-id: 97109c245b669bc6edff0f246893f95b7ae71f90
This commit is contained in:
		
							parent
							
								
									db3c12abfb
								
							
						
					
					
						commit
						209c160a20
					
				@ -4,6 +4,7 @@
 | 
			
		||||
# This source code is licensed under the BSD-style license found in the
 | 
			
		||||
# LICENSE file in the root directory of this source tree.
 | 
			
		||||
 | 
			
		||||
import inspect
 | 
			
		||||
import logging
 | 
			
		||||
import os
 | 
			
		||||
from typing import Any, Dict, Optional, Tuple
 | 
			
		||||
@ -61,6 +62,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
            increasing epoch indices at which the learning rate is modified.
 | 
			
		||||
        momentum: Momentum factor for SGD optimizer.
 | 
			
		||||
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
 | 
			
		||||
        foreach: Whether to use new "foreach" implementation of optimizer where
 | 
			
		||||
            available (e.g. requires PyTorch 1.12.0 for Adam)
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    betas: Tuple[float, ...] = (0.9, 0.999)
 | 
			
		||||
@ -74,6 +77,7 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
    weight_decay: float = 0.0
 | 
			
		||||
    linear_exponential_lr_milestone: int = 200
 | 
			
		||||
    linear_exponential_start_gamma: float = 0.1
 | 
			
		||||
    foreach: Optional[bool] = True
 | 
			
		||||
 | 
			
		||||
    def __post_init__(self):
 | 
			
		||||
        run_auto_creation(self)
 | 
			
		||||
@ -115,23 +119,24 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
 | 
			
		||||
            p_groups = [{"params": allprm, "lr": self.lr}]
 | 
			
		||||
 | 
			
		||||
        # Intialize the optimizer
 | 
			
		||||
        optimizer_kwargs: Dict[str, Any] = {
 | 
			
		||||
            "lr": self.lr,
 | 
			
		||||
            "weight_decay": self.weight_decay,
 | 
			
		||||
        }
 | 
			
		||||
        if self.breed == "SGD":
 | 
			
		||||
            optimizer = torch.optim.SGD(
 | 
			
		||||
                p_groups,
 | 
			
		||||
                lr=self.lr,
 | 
			
		||||
                momentum=self.momentum,
 | 
			
		||||
                weight_decay=self.weight_decay,
 | 
			
		||||
            )
 | 
			
		||||
            optimizer_class = torch.optim.SGD
 | 
			
		||||
            optimizer_kwargs["momentum"] = self.momentum
 | 
			
		||||
        elif self.breed == "Adagrad":
 | 
			
		||||
            optimizer = torch.optim.Adagrad(
 | 
			
		||||
                p_groups, lr=self.lr, weight_decay=self.weight_decay
 | 
			
		||||
            )
 | 
			
		||||
            optimizer_class = torch.optim.Adagrad
 | 
			
		||||
        elif self.breed == "Adam":
 | 
			
		||||
            optimizer = torch.optim.Adam(
 | 
			
		||||
                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
 | 
			
		||||
            )
 | 
			
		||||
            optimizer_class = torch.optim.Adam
 | 
			
		||||
            optimizer_kwargs["betas"] = self.betas
 | 
			
		||||
        else:
 | 
			
		||||
            raise ValueError(f"No such solver type {self.breed}")
 | 
			
		||||
 | 
			
		||||
        if "foreach" in inspect.signature(optimizer_class.__init__).parameters:
 | 
			
		||||
            optimizer_kwargs["foreach"] = self.foreach
 | 
			
		||||
        optimizer = optimizer_class(p_groups, **optimizer_kwargs)
 | 
			
		||||
        logger.info(f"Solver type = {self.breed}")
 | 
			
		||||
 | 
			
		||||
        # Load state from checkpoint
 | 
			
		||||
 | 
			
		||||
@ -406,6 +406,7 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
 | 
			
		||||
  weight_decay: 0.0
 | 
			
		||||
  linear_exponential_lr_milestone: 200
 | 
			
		||||
  linear_exponential_start_gamma: 0.1
 | 
			
		||||
  foreach: true
 | 
			
		||||
training_loop_ImplicitronTrainingLoop_args:
 | 
			
		||||
  evaluator_class_type: ImplicitronEvaluator
 | 
			
		||||
  evaluator_ImplicitronEvaluator_args:
 | 
			
		||||
 | 
			
		||||
@ -9,13 +9,17 @@ import tempfile
 | 
			
		||||
import unittest
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
 | 
			
		||||
import torch
 | 
			
		||||
 | 
			
		||||
from hydra import compose, initialize_config_dir
 | 
			
		||||
from omegaconf import OmegaConf
 | 
			
		||||
from projects.implicitron_trainer.impl.optimizer_factory import (
 | 
			
		||||
    ImplicitronOptimizerFactory,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
from .. import experiment
 | 
			
		||||
from .utils import interactive_testing_requested, intercept_logs
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
internal = os.environ.get("FB_TEST", False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -151,6 +155,16 @@ class TestExperiment(unittest.TestCase):
 | 
			
		||||
                with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
 | 
			
		||||
                    compose(file.name)
 | 
			
		||||
 | 
			
		||||
    def test_optimizer_factory(self):
 | 
			
		||||
        model = torch.nn.Linear(2, 2)
 | 
			
		||||
 | 
			
		||||
        adam, sched = ImplicitronOptimizerFactory(breed="Adam")(0, model)
 | 
			
		||||
        self.assertIsInstance(adam, torch.optim.Adam)
 | 
			
		||||
        sgd, sched = ImplicitronOptimizerFactory(breed="SGD")(0, model)
 | 
			
		||||
        self.assertIsInstance(sgd, torch.optim.SGD)
 | 
			
		||||
        adagrad, sched = ImplicitronOptimizerFactory(breed="Adagrad")(0, model)
 | 
			
		||||
        self.assertIsInstance(adagrad, torch.optim.Adagrad)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TestNerfRepro(unittest.TestCase):
 | 
			
		||||
    @unittest.skip("This test runs full blender training.")
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user