Nikhila Ravi aa8b03f31d Updates to support Accelerate and multigpu training (#37)
Summary:
## Changes:
- Added Accelerate Library and refactored experiment.py to use it
- Needed to move `init_optimizer` and `ExperimentConfig` to a separate file to be compatible with submitit/hydra
- Needed to make some modifications to data loaders etc to work well with the accelerate ddp wrappers
- Loading/saving checkpoints incorporates an unwrapping step to remove the DDP wrapper from the model
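
The unwrapping step matters because a DDP-style wrapper stores the real model as a `.module` attribute, so every key in the wrapped model's state dict gains a `module.` prefix. Accelerate exposes `Accelerator.unwrap_model` for this purpose. The sketch below is not the code from this change: `FakeDDP` and `unwrap_model` are hypothetical names used only so the effect can be shown without starting a process group.

```python
import torch


class FakeDDP(torch.nn.Module):
    """Hypothetical stand-in for a DDP wrapper: keeps the real model as `.module`."""

    def __init__(self, module: torch.nn.Module) -> None:
        super().__init__()
        self.module = module


def unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Strip wrapper layers that expose the wrapped model as `.module`."""
    while hasattr(model, "module") and isinstance(model.module, torch.nn.Module):
        model = model.module
    return model


net = torch.nn.Linear(4, 2)
wrapped = FakeDDP(net)

# Keys saved from the wrapped model carry the wrapper's prefix.
wrapped_keys = sorted(wrapped.state_dict().keys())
# Keys saved after unwrapping match what a single-process run expects.
plain_keys = sorted(unwrap_model(wrapped).state_dict().keys())

print(wrapped_keys)
print(plain_keys)
```

Saving the unwrapped state dict keeps checkpoints loadable in both single-GPU and multi-GPU runs.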

## Tests

Tested with both `torchrun` and `submitit/hydra` on two gpus locally. Here are the commands:

**Torchrun**

Modules loaded:
```sh
1) anaconda3/2021.05   2) cuda/11.3   3) NCCL/2.9.8-3-cuda.11.3   4) gcc/5.2.0. (but unload gcc when using submit)
```

```sh
torchrun --nnodes=1 --nproc_per_node=2 experiment.py --config-path ./configs --config-name repro_singleseq_nerf_test
```

**Submitit/Hydra Local test**

```sh
~/pytorch3d/projects/implicitron_trainer$ HYDRA_FULL_ERROR=1 python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_local hydra.launcher.gpus_per_node=2 hydra.launcher.tasks_per_node=2 hydra.launcher.nodes=1
```

**Submitit/Hydra distributed test**

```sh
~/implicitron/pytorch3d$ python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_slurm hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=8 hydra.launcher.nodes=1 hydra.launcher.partition=learnlab hydra.launcher.timeout_min=4320
```

## TODOS:
- Fix distributed evaluation: this currently doesn't work because the input format to the evaluation function is not suitable for gathering across GPUs (it needs to be a nested list/tuple/dict of objects that satisfy `is_torch_tensor`), and `frame_data` currently contains a `Cameras` object.
- Refactor the `accelerator` object to be accessible by all functions instead of needing to pass it around everywhere? Maybe have a `Trainer` class and add it as a method?
- Update readme with installation instructions for accelerate and also commands for running jobs with torchrun and submitit/hydra

X-link: https://github.com/fairinternal/pytorch3d/pull/37

Reviewed By: davnov134, kjchalup

Differential Revision: D37543870

Pulled By: bottler

fbshipit-source-id: be9eb4e91244d4fe3740d87dafec622ae1e0cf76
2022-07-11 19:29:58 -07:00

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args

logger = logging.getLogger(__name__)


def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, ...] = (0.9, 0.999),
    milestones: tuple = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved
        breed: The type of optimizer to use e.g. adam
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only "multistep"
            is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient and its square
            in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning rate is
            modified
        max_epochs: The maximum number of epochs to run the optimizer for

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raises:
        ValueError if `breed` or `lr_policy` are not supported.
    """
    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        # pyre-ignore[29]
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info("  -> solver type = %s" % breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info("  -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When loading from checkpoint, this will make sure that the
    # lr is correctly set even after returning
    for _ in range(last_epoch):
        scheduler.step()

    optimizer.zero_grad()
    return optimizer, scheduler


enable_get_default_args(init_optimizer)
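
The loop over `last_epoch` at the end of `init_optimizer` fast-forwards the scheduler so that a resumed run continues from the decayed learning rate, not the initial one. The sketch below reproduces only that step with plain PyTorch. The `torch.nn.Linear` model, the milestones `(2, 4)` and `last_epoch = 3` are illustrative assumptions, not values taken from the trainer configs.

```python
import warnings

import torch

model = torch.nn.Linear(4, 2)  # illustrative stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=(2, 4), gamma=0.1
)

last_epoch = 3  # assumed: epoch index restored from a checkpoint

with warnings.catch_warnings():
    # Stepping the scheduler before optimizer.step() emits a UserWarning;
    # it is silenced here because no training step has happened yet.
    warnings.simplefilter("ignore")
    for _ in range(last_epoch):
        scheduler.step()

# Only milestone 2 has been passed, so the rate was decayed by gamma once.
resumed_lr = optimizer.param_groups[0]["lr"]
print(resumed_lr)
```

With these assumed values the learning rate has been multiplied by `gamma` exactly once, because only the milestone at epoch 2 lies at or before the restored epoch.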