Mirror of https://github.com/facebookresearch/pytorch3d.git
Summary:

## Changes
- Added the Accelerate library and refactored experiment.py to use it.
- Moved `init_optimizer` and `ExperimentConfig` to a separate file to be compatible with submitit/hydra.
- Made some modifications to the data loaders etc. so they work well with the accelerate DDP wrappers.
- Loading/saving checkpoints now incorporates an unwrapping step to remove the DDP-wrapped model (see the sketch after this summary).

## Tests
Tested with both `torchrun` and `submitit/hydra` on two GPUs locally. Here are the commands:

**Torchrun**

Modules loaded:
```sh
1) anaconda3/2021.05  2) cuda/11.3  3) NCCL/2.9.8-3-cuda.11.3  4) gcc/5.2.0 (but unload gcc when using submitit)
```

```sh
torchrun --nnodes=1 --nproc_per_node=2 experiment.py --config-path ./configs --config-name repro_singleseq_nerf_test
```

**Submitit/Hydra local test**
```sh
~/pytorch3d/projects/implicitron_trainer$ HYDRA_FULL_ERROR=1 python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs hydra/launcher=submitit_local hydra.launcher.gpus_per_node=2 hydra.launcher.tasks_per_node=2 hydra.launcher.nodes=1
```

**Submitit/Hydra distributed test**
```sh
~/implicitron/pytorch3d$ python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs hydra/launcher=submitit_slurm hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=8 hydra.launcher.nodes=1 hydra.launcher.partition=learnlab hydra.launcher.timeout_min=4320
```

## TODOs
- Fix distributed evaluation: currently this doesn't work because the input format to the evaluation function is not suitable for gathering across GPUs (it needs to be nested lists/tuples/dicts of objects that satisfy `is_torch_tensor`), and `frame_data` currently contains the `Cameras` type.
- Refactor the `accelerator` object to be accessible by all functions instead of needing to pass it around everywhere? Maybe have a `Trainer` class and add it as a method?
- Update the README with installation instructions for accelerate, plus commands for running jobs with torchrun and submitit/hydra.

X-link: https://github.com/fairinternal/pytorch3d/pull/37

Reviewed By: davnov134, kjchalup

Differential Revision: D37543870

Pulled By: bottler

fbshipit-source-id: be9eb4e91244d4fe3740d87dafec622ae1e0cf76
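For context, here is a minimal sketch of the wrapping/unwrapping pattern the Changes describe. It is illustrative rather than the code in this diff: the toy model, data, and checkpoint path are placeholder assumptions, while the `Accelerator` calls (`prepare`, `backward`, `unwrap_model`) are the standard accelerate API.

```python
# Minimal sketch (not the PR code): accelerate DDP wrapping plus the checkpoint
# unwrapping step. The toy model, data, and path are placeholders.
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator()

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)

# accelerate wraps the model in DDP (when launched distributed) and shards the loader.
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)  # replaces loss.backward()
    optimizer.step()

# Unwrap before saving so the checkpoint holds the plain module's state dict,
# not the DDP-wrapped one.
if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(model)
    torch.save(unwrapped.state_dict(), "checkpoint.pth")
```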
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Optional, Tuple

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args

logger = logging.getLogger(__name__)

def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, ...] = (0.9, 0.999),
    milestones: tuple = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights.
        optimizer_state: The state dict for the optimizer. If None,
            it has not been loaded from a checkpoint.
        last_epoch: If the model was loaded from a checkpoint, this is the
            number of the last epoch that was saved.
        breed: The type of optimizer to use, e.g. "adam".
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
        lr_policy: The policy to use for the learning rate. Currently, only
            "multistep" is supported.
        lr: The value for the initial learning rate.
        gamma: Multiplicative factor of learning rate decay.
        momentum: Momentum factor for the SGD optimizer.
        betas: Coefficients used for computing running averages of the gradient
            and its square in the Adam optimizer.
        milestones: List of increasing epoch indices at which the learning rate
            is modified.
        max_epochs: The maximum number of epochs to run the optimizer for.

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint.
        scheduler: Learning rate scheduler module.

    Raises:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        # pyre-ignore[29]
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info(" -> solver type = %s" % breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info(" -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When resuming from a checkpoint, step the scheduler forward so that the
    # lr is correctly set for the current epoch.
    for _ in range(last_epoch):
        scheduler.step()

    optimizer.zero_grad()
    return optimizer, scheduler


enable_get_default_args(init_optimizer)
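
# Usage sketch (illustrative, not part of the original file). A training script
# might build the optimizer and scheduler as follows; `model` is assumed to be
# an already-constructed GenericModel:
#
#   optimizer, scheduler = init_optimizer(
#       model,
#       optimizer_state=None,  # or a state dict restored from a checkpoint
#       last_epoch=0,
#       breed="adam",
#       lr=5e-4,
#       milestones=(200, 300),
#   )
#
# The `enable_get_default_args(init_optimizer)` call above registers the function
# with the Implicitron config system, so its keyword defaults can be obtained via
# `get_default_args(init_optimizer)` (from pytorch3d.implicitron.tools.config)
# and overridden from the experiment config.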