Updates to support Accelerate and multigpu training (#37)

Summary:
## Changes:
- Added Accelerate Library and refactored experiment.py to use it
- Needed to move `init_optimizer` and `ExperimentConfig` to a separate file to be compatible with submitit/hydra
- Needed to make some modifications to data loaders etc to work well with the accelerate ddp wrappers
- Loading/saving checkpoints incorporates an unwrapping step so remove the ddp wrapped model

## Tests

Tested with both `torchrun` and `submitit/hydra` on two gpus locally. Here are the commands:

**Torchrun**

Modules loaded:
```sh
1) anaconda3/2021.05   2) cuda/11.3   3) NCCL/2.9.8-3-cuda.11.3   4) gcc/5.2.0. (but unload gcc when using submit)
```

```sh
torchrun --nnodes=1 --nproc_per_node=2 experiment.py --config-path ./configs --config-name repro_singleseq_nerf_test
```

**Submitit/Hydra Local test**

```sh
~/pytorch3d/projects/implicitron_trainer$ HYDRA_FULL_ERROR=1 python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_local hydra.launcher.gpus_per_node=2 hydra.launcher.tasks_per_node=2 hydra.launcher.nodes=1
```

**Submitit/Hydra distributed test**

```sh
~/implicitron/pytorch3d$ python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs  hydra/launcher=submitit_slurm hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=8 hydra.launcher.nodes=1 hydra.launcher.partition=learnlab hydra.launcher.timeout_min=4320
```

## TODOS:
- Fix distributed evaluation: currently this doesn't work as the input format to the evaluation function is not suitable for gathering across gpus (needs to be nested list/tuple/dicts of objects that satisfy `is_torch_tensor`) and currently `frame_data`  contains `Cameras` type.
- Refactor the `accelerator` object to be accessible by all functions instead of needing to pass it around everywhere? Maybe have a `Trainer` class and add it as a method?
- Update readme with installation instructions for accelerate and also commands for running jobs with torchrun and submitit/hydra

X-link: https://github.com/fairinternal/pytorch3d/pull/37

Reviewed By: davnov134, kjchalup

Differential Revision: D37543870

Pulled By: bottler

fbshipit-source-id: be9eb4e91244d4fe3740d87dafec622ae1e0cf76
This commit is contained in:
Nikhila Ravi 2022-07-11 19:29:58 -07:00 committed by Facebook GitHub Bot
parent 57a40b3688
commit aa8b03f31d
7 changed files with 290 additions and 162 deletions

View File

@ -45,7 +45,6 @@ The outputs of the experiment are saved and logged in multiple ways:
config file. config file.
""" """
import copy import copy
import json import json
import logging import logging
@ -53,7 +52,6 @@ import os
import random import random
import time import time
import warnings import warnings
from dataclasses import field
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Optional, Tuple
import hydra import hydra
@ -61,6 +59,7 @@ import lpips
import numpy as np import numpy as np
import torch import torch
import tqdm import tqdm
from accelerate import Accelerator
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from packaging import version from packaging import version
from pytorch3d.implicitron.dataset import utils as ds_utils from pytorch3d.implicitron.dataset import utils as ds_utils
@ -69,17 +68,20 @@ from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Tas
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
from pytorch3d.implicitron.models.renderer.multipass_ea import (
MultiPassEmissionAbsorptionRenderer,
)
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
from pytorch3d.implicitron.tools import model_io, vis_utils from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import ( from pytorch3d.implicitron.tools.config import (
Configurable,
enable_get_default_args,
expand_args_fields, expand_args_fields,
get_default_args_field,
remove_unused_components, remove_unused_components,
) )
from pytorch3d.implicitron.tools.stats import Stats from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase from pytorch3d.renderer.cameras import CamerasBase
from .impl.experiment_config import ExperimentConfig
from .impl.optimization import init_optimizer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -101,6 +103,7 @@ def init_model(
force_load: bool = False, force_load: bool = False,
clear_stats: bool = False, clear_stats: bool = False,
load_model_only: bool = False, load_model_only: bool = False,
accelerator: Accelerator = None,
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]: ) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
""" """
Returns an instance of `GenericModel`. Returns an instance of `GenericModel`.
@ -161,12 +164,20 @@ def init_model(
logger.info("found previous model %s" % model_path) logger.info("found previous model %s" % model_path)
if force_load or cfg.resume: if force_load or cfg.resume:
logger.info(" -> resuming") logger.info(" -> resuming")
map_location = None
if not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
if load_model_only: if load_model_only:
model_state_dict = torch.load(model_io.get_model_path(model_path)) model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
)
stats_load, optimizer_state = None, None stats_load, optimizer_state = None, None
else: else:
model_state_dict, stats_load, optimizer_state = model_io.load_model( model_state_dict, stats_load, optimizer_state = model_io.load_model(
model_path model_path, map_location=map_location
) )
# Determine if stats should be reset # Determine if stats should be reset
@ -210,101 +221,6 @@ def init_model(
return model, stats, optimizer_state return model, stats, optimizer_state
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep:
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Intialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)
def trainvalidate( def trainvalidate(
model, model,
stats, stats,
@ -318,6 +234,7 @@ def trainvalidate(
visdom_env_root: str = "trainvalidate", visdom_env_root: str = "trainvalidate",
clip_grad: float = 0.0, clip_grad: float = 0.0,
device: str = "cuda:0", device: str = "cuda:0",
accelerator: Accelerator = None,
**kwargs, **kwargs,
) -> None: ) -> None:
""" """
@ -365,11 +282,11 @@ def trainvalidate(
# Iterate through the batches # Iterate through the batches
n_batches = len(loader) n_batches = len(loader)
for it, batch in enumerate(loader): for it, net_input in enumerate(loader):
last_iter = it == n_batches - 1 last_iter = it == n_batches - 1
# move to gpu where possible (in place) # move to gpu where possible (in place)
net_input = batch.to(device) net_input = net_input.to(accelerator.device)
# run the forward pass # run the forward pass
if not validation: if not validation:
@ -395,7 +312,11 @@ def trainvalidate(
stats.print(stat_set=trainmode, max_it=n_batches) stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results # visualize results
if visualize_interval > 0 and it % visualize_interval == 0: if (
accelerator.is_local_main_process
and visualize_interval > 0
and it % visualize_interval == 0
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}" prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
model.visualize( model.visualize(
@ -410,7 +331,7 @@ def trainvalidate(
loss = preds[bp_var] loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!" assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop # backprop
loss.backward() accelerator.backward(loss)
if clip_grad > 0.0: if clip_grad > 0.0:
# Optionally clip the gradient norms. # Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm( total_norm = torch.nn.utils.clip_grad_norm(
@ -425,12 +346,22 @@ def trainvalidate(
optimizer.step() optimizer.step()
def run_training(cfg: DictConfig, device: str = "cpu") -> None: def run_training(cfg: DictConfig) -> None:
""" """
Entry point to run the training and validation loops Entry point to run the training and validation loops
based on the specified config file. based on the specified config file.
""" """
# Initialize the accelerator
accelerator = Accelerator(device_placement=False)
logger.info(accelerator.state)
device = accelerator.device
logger.info(f"Running experiment on device: {device}")
if accelerator.is_local_main_process:
logger.info(OmegaConf.to_yaml(cfg))
# set the debug mode # set the debug mode
if cfg.detect_anomaly: if cfg.detect_anomaly:
logger.info("Anomaly detection!") logger.info("Anomaly detection!")
@ -455,11 +386,11 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
all_train_cameras = datasource.get_all_train_cameras() all_train_cameras = datasource.get_all_train_cameras()
# init the model # init the model
model, stats, optimizer_state = init_model(cfg) model, stats, optimizer_state = init_model(cfg, accelerator=accelerator)
start_epoch = stats.epoch + 1 start_epoch = stats.epoch + 1
# move model to gpu # move model to gpu
model.to(device) model.to(accelerator.device)
# only run evaluation on the test dataloader # only run evaluation on the test dataloader
if cfg.eval_only: if cfg.eval_only:
@ -472,6 +403,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
model, model,
stats, stats,
device=device, device=device,
accelerator=accelerator,
) )
return return
@ -487,6 +419,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
assert scheduler.last_epoch == stats.epoch + 1 assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch assert scheduler.last_epoch == start_epoch
# Wrap all modules in the distributed library
# Note: we don't pass the scheduler to prepare as it
# doesn't need to be stepped at each optimizer step
(
model,
optimizer,
train_loader,
val_loader,
) = accelerator.prepare(model, optimizer, dataloaders.train, dataloaders.val)
past_scheduler_lrs = [] past_scheduler_lrs = []
# loop through epochs # loop through epochs
for epoch in range(start_epoch, cfg.solver_args.max_epochs): for epoch in range(start_epoch, cfg.solver_args.max_epochs):
@ -506,25 +448,27 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
model, model,
stats, stats,
epoch, epoch,
dataloaders.train, train_loader,
optimizer, optimizer,
False, False,
visdom_env_root=vis_utils.get_visdom_env(cfg), visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device, device=device,
accelerator=accelerator,
**cfg, **cfg,
) )
# val loop (optional) # val loop (optional)
if dataloaders.val is not None and epoch % cfg.validation_interval == 0: if val_loader is not None and epoch % cfg.validation_interval == 0:
trainvalidate( trainvalidate(
model, model,
stats, stats,
epoch, epoch,
dataloaders.val, val_loader,
optimizer, optimizer,
True, True,
visdom_env_root=vis_utils.get_visdom_env(cfg), visdom_env_root=vis_utils.get_visdom_env(cfg),
device=device, device=device,
accelerator=accelerator,
**cfg, **cfg,
) )
@ -541,18 +485,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
task, task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device, device=device,
accelerator=accelerator,
) )
assert stats.epoch == epoch, "inconsistent stats!" assert stats.epoch == epoch, "inconsistent stats!"
# delete previous models if required # delete previous models if required
# save model # save model only on the main process
if cfg.store_checkpoints: if cfg.store_checkpoints and accelerator.is_local_main_process:
if cfg.store_checkpoints_purge > 0: if cfg.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - cfg.store_checkpoints_purge): for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
model_io.purge_epoch(cfg.exp_dir, prev_epoch) model_io.purge_epoch(cfg.exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch) outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer) unwrapped_model = accelerator.unwrap_model(model)
model_io.safe_save_model(
unwrapped_model, stats, outfile, optimizer=optimizer
)
scheduler.step() scheduler.step()
@ -582,6 +530,7 @@ def _eval_and_dump(
model, model,
stats, stats,
device, device,
accelerator: Accelerator = None,
) -> None: ) -> None:
""" """
Run the evaluation loop with the test data loader and Run the evaluation loop with the test data loader and
@ -600,6 +549,7 @@ def _eval_and_dump(
task, task,
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks, camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
device=device, device=device,
accelerator=accelerator,
) )
# add the evaluation epoch to the results # add the evaluation epoch to the results
@ -634,19 +584,20 @@ def _run_eval(
task: Task, task: Task,
camera_difficulty_bin_breaks: Tuple[float, float], camera_difficulty_bin_breaks: Tuple[float, float],
device, device,
accelerator: Accelerator = None,
): ):
""" """
Run the evaluation loop on the test dataloader Run the evaluation loop on the test dataloader
""" """
lpips_model = lpips.LPIPS(net="vgg") lpips_model = lpips.LPIPS(net="vgg")
lpips_model = lpips_model.to(device) lpips_model = lpips_model.to(accelerator.device)
model.eval() model.eval()
per_batch_eval_results = [] per_batch_eval_results = []
logger.info("Evaluating model ...") logger.info("Evaluating model ...")
for frame_data in tqdm.tqdm(loader): for frame_data in tqdm.tqdm(loader):
frame_data = frame_data.to(device) frame_data = frame_data.to(accelerator.device)
# mask out the unknown images so that the model does not see them # mask out the unknown images so that the model does not see them
frame_data_for_eval = _get_eval_frame_data(frame_data) frame_data_for_eval = _get_eval_frame_data(frame_data)
@ -655,7 +606,15 @@ def _run_eval(
preds = model( preds = model(
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION} **{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
) )
# TODO: Cannot use accelerate gather for two reasons:.
# (1) TypeError: Can't apply _gpu_gather_one on object of type
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
# (2) Same error above but for frame_data which contains Cameras.
implicitron_render = copy.deepcopy(preds["implicitron_render"]) implicitron_render = copy.deepcopy(preds["implicitron_render"])
per_batch_eval_results.append( per_batch_eval_results.append(
evaluate.eval_batch( evaluate.eval_batch(
frame_data, frame_data,
@ -673,62 +632,65 @@ def _run_eval(
return category_result["results"] return category_result["results"]
def _seed_all_random_engines(seed: int): def _seed_all_random_engines(seed: int) -> None:
np.random.seed(seed) np.random.seed(seed)
torch.manual_seed(seed) torch.manual_seed(seed)
random.seed(seed) random.seed(seed)
class ExperimentConfig(Configurable): def _setup_envvars_for_cluster(cfg) -> bool:
generic_model_args: DictConfig = get_default_args_field(GenericModel) """
solver_args: DictConfig = get_default_args_field(init_optimizer) Prepares to run on cluster if relevant.
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource) Returns whether FAIR cluster in use.
architecture: str = "generic" """
detect_anomaly: bool = False # TODO: How much of this is needed in general?
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field( try:
default_factory=lambda: { import submitit
"run": {"dir": "."}, # Make hydra not change the working dir. except ImportError:
"output_subdir": None, # disable storing the .hydra logs return False
}
try:
# Only needed when launching on cluster with slurm and submitit
job_env = submitit.JobEnvironment()
except RuntimeError:
return False
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
os.environ["RANK"] = str(job_env.global_rank)
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "42918"
logger.info(
"Num tasks %s, global_rank %s"
% (str(job_env.num_tasks), str(job_env.global_rank))
) )
return True
expand_args_fields(ExperimentConfig) expand_args_fields(ExperimentConfig)
cs = hydra.core.config_store.ConfigStore.instance() cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config", node=ExperimentConfig) cs.store(name="default_config", node=ExperimentConfig)
@hydra.main(config_path="./configs/", config_name="default_config") @hydra.main(config_path="./configs/", config_name="default_config")
def experiment(cfg: DictConfig) -> None: def experiment(cfg: DictConfig) -> None:
# CUDA_VISIBLE_DEVICES must have been set.
if "CUDA_DEVICE_ORDER" not in os.environ:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
# Set the device if not _setup_envvars_for_cluster():
device = "cpu" logger.info("Running locally")
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
device = f"cuda:{cfg.gpu_idx}" # TODO: The following may be needed for hydra/submitit it to work
logger.info(f"Running experiment on device: {device}") expand_args_fields(GenericModel)
run_training(cfg, device) expand_args_fields(AdaptiveRaySampler)
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
expand_args_fields(ImplicitronDataSource)
run_training(cfg)
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

View File

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
from .optimization import init_optimizer
class ExperimentConfig(Configurable):
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)

View File

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args
logger = logging.getLogger(__name__)
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: tuple = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep:
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Intialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)

View File

@ -8,11 +8,12 @@ import os
import unittest import unittest
from pathlib import Path from pathlib import Path
import experiment
import torch import torch
from hydra import compose, initialize_config_dir from hydra import compose, initialize_config_dir
from omegaconf import OmegaConf from omegaconf import OmegaConf
from .. import experiment
def interactive_testing_requested() -> bool: def interactive_testing_requested() -> bool:
""" """

View File

@ -21,7 +21,6 @@ from typing import Optional, Tuple
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as Fu import torch.nn.functional as Fu
from experiment import init_model
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
@ -38,6 +37,8 @@ from pytorch3d.implicitron.tools.vis_utils import (
) )
from tqdm import tqdm from tqdm import tqdm
from .experiment import init_model
def render_sequence( def render_sequence(
dataset: DatasetBase, dataset: DatasetBase,

View File

@ -9,6 +9,7 @@ import logging
import os import os
import shutil import shutil
import tempfile import tempfile
from typing import Optional
import torch import torch
@ -99,14 +100,14 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
return flstats, flmodel, flopt return flstats, flmodel, flopt
def load_model(fl): def load_model(fl, map_location: Optional[dict]):
flstats = get_stats_path(fl) flstats = get_stats_path(fl)
flmodel = get_model_path(fl) flmodel = get_model_path(fl)
flopt = get_optimizer_path(fl) flopt = get_optimizer_path(fl)
model_state_dict = torch.load(flmodel) model_state_dict = torch.load(flmodel, map_location=map_location)
stats = load_stats(flstats) stats = load_stats(flstats)
if os.path.isfile(flopt): if os.path.isfile(flopt):
optimizer = torch.load(flopt) optimizer = torch.load(flopt, map_location=map_location)
else: else:
optimizer = None optimizer = None