mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-08-02 03:42:50 +08:00
Updates to support Accelerate and multigpu training (#37)
Summary: ## Changes: - Added Accelerate Library and refactored experiment.py to use it - Needed to move `init_optimizer` and `ExperimentConfig` to a separate file to be compatible with submitit/hydra - Needed to make some modifications to data loaders etc to work well with the accelerate ddp wrappers - Loading/saving checkpoints incorporates an unwrapping step so remove the ddp wrapped model ## Tests Tested with both `torchrun` and `submitit/hydra` on two gpus locally. Here are the commands: **Torchrun** Modules loaded: ```sh 1) anaconda3/2021.05 2) cuda/11.3 3) NCCL/2.9.8-3-cuda.11.3 4) gcc/5.2.0. (but unload gcc when using submit) ``` ```sh torchrun --nnodes=1 --nproc_per_node=2 experiment.py --config-path ./configs --config-name repro_singleseq_nerf_test ``` **Submitit/Hydra Local test** ```sh ~/pytorch3d/projects/implicitron_trainer$ HYDRA_FULL_ERROR=1 python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs hydra/launcher=submitit_local hydra.launcher.gpus_per_node=2 hydra.launcher.tasks_per_node=2 hydra.launcher.nodes=1 ``` **Submitit/Hydra distributed test** ```sh ~/implicitron/pytorch3d$ python3.9 experiment.py --config-name repro_singleseq_nerf_test --multirun --config-path ./configs hydra/launcher=submitit_slurm hydra.launcher.gpus_per_node=8 hydra.launcher.tasks_per_node=8 hydra.launcher.nodes=1 hydra.launcher.partition=learnlab hydra.launcher.timeout_min=4320 ``` ## TODOS: - Fix distributed evaluation: currently this doesn't work as the input format to the evaluation function is not suitable for gathering across gpus (needs to be nested list/tuple/dicts of objects that satisfy `is_torch_tensor`) and currently `frame_data` contains `Cameras` type. - Refactor the `accelerator` object to be accessible by all functions instead of needing to pass it around everywhere? Maybe have a `Trainer` class and add it as a method? - Update readme with installation instructions for accelerate and also commands for running jobs with torchrun and submitit/hydra X-link: https://github.com/fairinternal/pytorch3d/pull/37 Reviewed By: davnov134, kjchalup Differential Revision: D37543870 Pulled By: bottler fbshipit-source-id: be9eb4e91244d4fe3740d87dafec622ae1e0cf76
This commit is contained in:
parent
57a40b3688
commit
aa8b03f31d
@ -45,7 +45,6 @@ The outputs of the experiment are saved and logged in multiple ways:
|
||||
config file.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
@ -53,7 +52,6 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from dataclasses import field
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import hydra
|
||||
@ -61,6 +59,7 @@ import lpips
|
||||
import numpy as np
|
||||
import torch
|
||||
import tqdm
|
||||
from accelerate import Accelerator
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from packaging import version
|
||||
from pytorch3d.implicitron.dataset import utils as ds_utils
|
||||
@ -69,17 +68,20 @@ from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource, Tas
|
||||
from pytorch3d.implicitron.dataset.dataset_map_provider import DatasetMap
|
||||
from pytorch3d.implicitron.evaluation import evaluate_new_view_synthesis as evaluate
|
||||
from pytorch3d.implicitron.models.generic_model import EvaluationMode, GenericModel
|
||||
from pytorch3d.implicitron.models.renderer.multipass_ea import (
|
||||
MultiPassEmissionAbsorptionRenderer,
|
||||
)
|
||||
from pytorch3d.implicitron.models.renderer.ray_sampler import AdaptiveRaySampler
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
Configurable,
|
||||
enable_get_default_args,
|
||||
expand_args_fields,
|
||||
get_default_args_field,
|
||||
remove_unused_components,
|
||||
)
|
||||
from pytorch3d.implicitron.tools.stats import Stats
|
||||
from pytorch3d.renderer.cameras import CamerasBase
|
||||
|
||||
from .impl.experiment_config import ExperimentConfig
|
||||
from .impl.optimization import init_optimizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -101,6 +103,7 @@ def init_model(
|
||||
force_load: bool = False,
|
||||
clear_stats: bool = False,
|
||||
load_model_only: bool = False,
|
||||
accelerator: Accelerator = None,
|
||||
) -> Tuple[GenericModel, Stats, Optional[Dict[str, Any]]]:
|
||||
"""
|
||||
Returns an instance of `GenericModel`.
|
||||
@ -161,12 +164,20 @@ def init_model(
|
||||
logger.info("found previous model %s" % model_path)
|
||||
if force_load or cfg.resume:
|
||||
logger.info(" -> resuming")
|
||||
|
||||
map_location = None
|
||||
if not accelerator.is_local_main_process:
|
||||
map_location = {
|
||||
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
|
||||
}
|
||||
if load_model_only:
|
||||
model_state_dict = torch.load(model_io.get_model_path(model_path))
|
||||
model_state_dict = torch.load(
|
||||
model_io.get_model_path(model_path), map_location=map_location
|
||||
)
|
||||
stats_load, optimizer_state = None, None
|
||||
else:
|
||||
model_state_dict, stats_load, optimizer_state = model_io.load_model(
|
||||
model_path
|
||||
model_path, map_location=map_location
|
||||
)
|
||||
|
||||
# Determine if stats should be reset
|
||||
@ -210,101 +221,6 @@ def init_model(
|
||||
return model, stats, optimizer_state
|
||||
|
||||
|
||||
def init_optimizer(
|
||||
model: GenericModel,
|
||||
optimizer_state: Optional[Dict[str, Any]],
|
||||
last_epoch: int,
|
||||
breed: str = "adam",
|
||||
weight_decay: float = 0.0,
|
||||
lr_policy: str = "multistep",
|
||||
lr: float = 0.0005,
|
||||
gamma: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
betas: Tuple[float, ...] = (0.9, 0.999),
|
||||
milestones: tuple = (),
|
||||
max_epochs: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the optimizer (optionally from checkpoint state)
|
||||
and the learning rate scheduler.
|
||||
|
||||
Args:
|
||||
model: The model with optionally loaded weights
|
||||
optimizer_state: The state dict for the optimizer. If None
|
||||
it has not been loaded from checkpoint
|
||||
last_epoch: If the model was loaded from checkpoint this will be the
|
||||
number of the last epoch that was saved
|
||||
breed: The type of optimizer to use e.g. adam
|
||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
||||
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
||||
is supported.
|
||||
lr: The value for the initial learning rate
|
||||
gamma: Multiplicative factor of learning rate decay
|
||||
momentum: Momentum factor for SGD optimizer
|
||||
betas: Coefficients used for computing running averages of gradient and its square
|
||||
in the Adam optimizer
|
||||
milestones: List of increasing epoch indices at which the learning rate is
|
||||
modified
|
||||
max_epochs: The maximum number of epochs to run the optimizer for
|
||||
|
||||
Returns:
|
||||
optimizer: Optimizer module, optionally loaded from checkpoint
|
||||
scheduler: Learning rate scheduler module
|
||||
|
||||
Raise:
|
||||
ValueError if `breed` or `lr_policy` are not supported.
|
||||
"""
|
||||
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-ignore[29]
|
||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||
else:
|
||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||
p_groups = [{"params": allprm, "lr": lr}]
|
||||
|
||||
# Intialize the optimizer
|
||||
if breed == "sgd":
|
||||
optimizer = torch.optim.SGD(
|
||||
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
||||
)
|
||||
elif breed == "adagrad":
|
||||
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
||||
elif breed == "adam":
|
||||
optimizer = torch.optim.Adam(
|
||||
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such solver type %s" % breed)
|
||||
logger.info(" -> solver type = %s" % breed)
|
||||
|
||||
# Load state from checkpoint
|
||||
if optimizer_state is not None:
|
||||
logger.info(" -> setting loaded optimizer state")
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
# Initialize the learning rate scheduler
|
||||
if lr_policy == "multistep":
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=milestones,
|
||||
gamma=gamma,
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such lr policy %s" % lr_policy)
|
||||
|
||||
# When loading from checkpoint, this will make sure that the
|
||||
# lr is correctly set even after returning
|
||||
for _ in range(last_epoch):
|
||||
scheduler.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
||||
|
||||
|
||||
def trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
@ -318,6 +234,7 @@ def trainvalidate(
|
||||
visdom_env_root: str = "trainvalidate",
|
||||
clip_grad: float = 0.0,
|
||||
device: str = "cuda:0",
|
||||
accelerator: Accelerator = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
@ -365,11 +282,11 @@ def trainvalidate(
|
||||
|
||||
# Iterate through the batches
|
||||
n_batches = len(loader)
|
||||
for it, batch in enumerate(loader):
|
||||
for it, net_input in enumerate(loader):
|
||||
last_iter = it == n_batches - 1
|
||||
|
||||
# move to gpu where possible (in place)
|
||||
net_input = batch.to(device)
|
||||
net_input = net_input.to(accelerator.device)
|
||||
|
||||
# run the forward pass
|
||||
if not validation:
|
||||
@ -395,7 +312,11 @@ def trainvalidate(
|
||||
stats.print(stat_set=trainmode, max_it=n_batches)
|
||||
|
||||
# visualize results
|
||||
if visualize_interval > 0 and it % visualize_interval == 0:
|
||||
if (
|
||||
accelerator.is_local_main_process
|
||||
and visualize_interval > 0
|
||||
and it % visualize_interval == 0
|
||||
):
|
||||
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
|
||||
|
||||
model.visualize(
|
||||
@ -410,7 +331,7 @@ def trainvalidate(
|
||||
loss = preds[bp_var]
|
||||
assert torch.isfinite(loss).all(), "Non-finite loss!"
|
||||
# backprop
|
||||
loss.backward()
|
||||
accelerator.backward(loss)
|
||||
if clip_grad > 0.0:
|
||||
# Optionally clip the gradient norms.
|
||||
total_norm = torch.nn.utils.clip_grad_norm(
|
||||
@ -425,12 +346,22 @@ def trainvalidate(
|
||||
optimizer.step()
|
||||
|
||||
|
||||
def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
def run_training(cfg: DictConfig) -> None:
|
||||
"""
|
||||
Entry point to run the training and validation loops
|
||||
based on the specified config file.
|
||||
"""
|
||||
|
||||
# Initialize the accelerator
|
||||
accelerator = Accelerator(device_placement=False)
|
||||
logger.info(accelerator.state)
|
||||
|
||||
device = accelerator.device
|
||||
logger.info(f"Running experiment on device: {device}")
|
||||
|
||||
if accelerator.is_local_main_process:
|
||||
logger.info(OmegaConf.to_yaml(cfg))
|
||||
|
||||
# set the debug mode
|
||||
if cfg.detect_anomaly:
|
||||
logger.info("Anomaly detection!")
|
||||
@ -455,11 +386,11 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
all_train_cameras = datasource.get_all_train_cameras()
|
||||
|
||||
# init the model
|
||||
model, stats, optimizer_state = init_model(cfg)
|
||||
model, stats, optimizer_state = init_model(cfg, accelerator=accelerator)
|
||||
start_epoch = stats.epoch + 1
|
||||
|
||||
# move model to gpu
|
||||
model.to(device)
|
||||
model.to(accelerator.device)
|
||||
|
||||
# only run evaluation on the test dataloader
|
||||
if cfg.eval_only:
|
||||
@ -472,6 +403,7 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
model,
|
||||
stats,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
return
|
||||
|
||||
@ -487,6 +419,16 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
assert scheduler.last_epoch == stats.epoch + 1
|
||||
assert scheduler.last_epoch == start_epoch
|
||||
|
||||
# Wrap all modules in the distributed library
|
||||
# Note: we don't pass the scheduler to prepare as it
|
||||
# doesn't need to be stepped at each optimizer step
|
||||
(
|
||||
model,
|
||||
optimizer,
|
||||
train_loader,
|
||||
val_loader,
|
||||
) = accelerator.prepare(model, optimizer, dataloaders.train, dataloaders.val)
|
||||
|
||||
past_scheduler_lrs = []
|
||||
# loop through epochs
|
||||
for epoch in range(start_epoch, cfg.solver_args.max_epochs):
|
||||
@ -506,25 +448,27 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders.train,
|
||||
train_loader,
|
||||
optimizer,
|
||||
False,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
**cfg,
|
||||
)
|
||||
|
||||
# val loop (optional)
|
||||
if dataloaders.val is not None and epoch % cfg.validation_interval == 0:
|
||||
if val_loader is not None and epoch % cfg.validation_interval == 0:
|
||||
trainvalidate(
|
||||
model,
|
||||
stats,
|
||||
epoch,
|
||||
dataloaders.val,
|
||||
val_loader,
|
||||
optimizer,
|
||||
True,
|
||||
visdom_env_root=vis_utils.get_visdom_env(cfg),
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
**cfg,
|
||||
)
|
||||
|
||||
@ -541,18 +485,22 @@ def run_training(cfg: DictConfig, device: str = "cpu") -> None:
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
assert stats.epoch == epoch, "inconsistent stats!"
|
||||
|
||||
# delete previous models if required
|
||||
# save model
|
||||
if cfg.store_checkpoints:
|
||||
# save model only on the main process
|
||||
if cfg.store_checkpoints and accelerator.is_local_main_process:
|
||||
if cfg.store_checkpoints_purge > 0:
|
||||
for prev_epoch in range(epoch - cfg.store_checkpoints_purge):
|
||||
model_io.purge_epoch(cfg.exp_dir, prev_epoch)
|
||||
outfile = model_io.get_checkpoint(cfg.exp_dir, epoch)
|
||||
model_io.safe_save_model(model, stats, outfile, optimizer=optimizer)
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
model_io.safe_save_model(
|
||||
unwrapped_model, stats, outfile, optimizer=optimizer
|
||||
)
|
||||
|
||||
scheduler.step()
|
||||
|
||||
@ -582,6 +530,7 @@ def _eval_and_dump(
|
||||
model,
|
||||
stats,
|
||||
device,
|
||||
accelerator: Accelerator = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the evaluation loop with the test data loader and
|
||||
@ -600,6 +549,7 @@ def _eval_and_dump(
|
||||
task,
|
||||
camera_difficulty_bin_breaks=cfg.camera_difficulty_bin_breaks,
|
||||
device=device,
|
||||
accelerator=accelerator,
|
||||
)
|
||||
|
||||
# add the evaluation epoch to the results
|
||||
@ -634,19 +584,20 @@ def _run_eval(
|
||||
task: Task,
|
||||
camera_difficulty_bin_breaks: Tuple[float, float],
|
||||
device,
|
||||
accelerator: Accelerator = None,
|
||||
):
|
||||
"""
|
||||
Run the evaluation loop on the test dataloader
|
||||
"""
|
||||
lpips_model = lpips.LPIPS(net="vgg")
|
||||
lpips_model = lpips_model.to(device)
|
||||
lpips_model = lpips_model.to(accelerator.device)
|
||||
|
||||
model.eval()
|
||||
|
||||
per_batch_eval_results = []
|
||||
logger.info("Evaluating model ...")
|
||||
for frame_data in tqdm.tqdm(loader):
|
||||
frame_data = frame_data.to(device)
|
||||
frame_data = frame_data.to(accelerator.device)
|
||||
|
||||
# mask out the unknown images so that the model does not see them
|
||||
frame_data_for_eval = _get_eval_frame_data(frame_data)
|
||||
@ -655,7 +606,15 @@ def _run_eval(
|
||||
preds = model(
|
||||
**{**frame_data_for_eval, "evaluation_mode": EvaluationMode.EVALUATION}
|
||||
)
|
||||
|
||||
# TODO: Cannot use accelerate gather for two reasons:.
|
||||
# (1) TypeError: Can't apply _gpu_gather_one on object of type
|
||||
# <class 'pytorch3d.implicitron.models.base_model.ImplicitronRender'>,
|
||||
# only of nested list/tuple/dicts of objects that satisfy is_torch_tensor.
|
||||
# (2) Same error above but for frame_data which contains Cameras.
|
||||
|
||||
implicitron_render = copy.deepcopy(preds["implicitron_render"])
|
||||
|
||||
per_batch_eval_results.append(
|
||||
evaluate.eval_batch(
|
||||
frame_data,
|
||||
@ -673,62 +632,65 @@ def _run_eval(
|
||||
return category_result["results"]
|
||||
|
||||
|
||||
def _seed_all_random_engines(seed: int):
|
||||
def _seed_all_random_engines(seed: int) -> None:
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
class ExperimentConfig(Configurable):
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
|
||||
architecture: str = "generic"
|
||||
detect_anomaly: bool = False
|
||||
eval_only: bool = False
|
||||
exp_dir: str = "./data/default_experiment/"
|
||||
exp_idx: int = 0
|
||||
gpu_idx: int = 0
|
||||
metric_print_interval: int = 5
|
||||
resume: bool = True
|
||||
resume_epoch: int = -1
|
||||
seed: int = 0
|
||||
store_checkpoints: bool = True
|
||||
store_checkpoints_purge: int = 1
|
||||
test_interval: int = -1
|
||||
test_when_finished: bool = False
|
||||
validation_interval: int = 1
|
||||
visdom_env: str = ""
|
||||
visdom_port: int = 8097
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
visualize_interval: int = 1000
|
||||
clip_grad: float = 0.0
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
def _setup_envvars_for_cluster(cfg) -> bool:
|
||||
"""
|
||||
Prepares to run on cluster if relevant.
|
||||
Returns whether FAIR cluster in use.
|
||||
"""
|
||||
# TODO: How much of this is needed in general?
|
||||
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
||||
"output_subdir": None, # disable storing the .hydra logs
|
||||
}
|
||||
try:
|
||||
import submitit
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Only needed when launching on cluster with slurm and submitit
|
||||
job_env = submitit.JobEnvironment()
|
||||
except RuntimeError:
|
||||
return False
|
||||
|
||||
os.environ["LOCAL_RANK"] = str(job_env.local_rank)
|
||||
os.environ["RANK"] = str(job_env.global_rank)
|
||||
os.environ["WORLD_SIZE"] = str(job_env.num_tasks)
|
||||
os.environ["MASTER_ADDR"] = "localhost"
|
||||
os.environ["MASTER_PORT"] = "42918"
|
||||
logger.info(
|
||||
"Num tasks %s, global_rank %s"
|
||||
% (str(job_env.num_tasks), str(job_env.global_rank))
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
expand_args_fields(ExperimentConfig)
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config", node=ExperimentConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config")
|
||||
def experiment(cfg: DictConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
# Set the device
|
||||
device = "cpu"
|
||||
if torch.cuda.is_available() and cfg.gpu_idx < torch.cuda.device_count():
|
||||
device = f"cuda:{cfg.gpu_idx}"
|
||||
logger.info(f"Running experiment on device: {device}")
|
||||
run_training(cfg, device)
|
||||
# CUDA_VISIBLE_DEVICES must have been set.
|
||||
|
||||
if "CUDA_DEVICE_ORDER" not in os.environ:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
|
||||
if not _setup_envvars_for_cluster():
|
||||
logger.info("Running locally")
|
||||
|
||||
# TODO: The following may be needed for hydra/submitit it to work
|
||||
expand_args_fields(GenericModel)
|
||||
expand_args_fields(AdaptiveRaySampler)
|
||||
expand_args_fields(MultiPassEmissionAbsorptionRenderer)
|
||||
expand_args_fields(ImplicitronDataSource)
|
||||
|
||||
run_training(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
5
projects/implicitron_trainer/impl/__init__.py
Normal file
5
projects/implicitron_trainer/impl/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
49
projects/implicitron_trainer/impl/experiment_config.py
Normal file
@ -0,0 +1,49 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from dataclasses import field
|
||||
from typing import Tuple
|
||||
|
||||
from omegaconf import DictConfig
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
|
||||
|
||||
from .optimization import init_optimizer
|
||||
|
||||
|
||||
class ExperimentConfig(Configurable):
|
||||
generic_model_args: DictConfig = get_default_args_field(GenericModel)
|
||||
solver_args: DictConfig = get_default_args_field(init_optimizer)
|
||||
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
|
||||
architecture: str = "generic"
|
||||
detect_anomaly: bool = False
|
||||
eval_only: bool = False
|
||||
exp_dir: str = "./data/default_experiment/"
|
||||
exp_idx: int = 0
|
||||
gpu_idx: int = 0
|
||||
metric_print_interval: int = 5
|
||||
resume: bool = True
|
||||
resume_epoch: int = -1
|
||||
seed: int = 0
|
||||
store_checkpoints: bool = True
|
||||
store_checkpoints_purge: int = 1
|
||||
test_interval: int = -1
|
||||
test_when_finished: bool = False
|
||||
validation_interval: int = 1
|
||||
visdom_env: str = ""
|
||||
visdom_port: int = 8097
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
visualize_interval: int = 1000
|
||||
clip_grad: float = 0.0
|
||||
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
|
||||
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."}, # Make hydra not change the working dir.
|
||||
"output_subdir": None, # disable storing the .hydra logs
|
||||
}
|
||||
)
|
109
projects/implicitron_trainer/impl/optimization.py
Normal file
109
projects/implicitron_trainer/impl/optimization.py
Normal file
@ -0,0 +1,109 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from pytorch3d.implicitron.models.generic_model import GenericModel
|
||||
from pytorch3d.implicitron.tools.config import enable_get_default_args
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def init_optimizer(
|
||||
model: GenericModel,
|
||||
optimizer_state: Optional[Dict[str, Any]],
|
||||
last_epoch: int,
|
||||
breed: str = "adam",
|
||||
weight_decay: float = 0.0,
|
||||
lr_policy: str = "multistep",
|
||||
lr: float = 0.0005,
|
||||
gamma: float = 0.1,
|
||||
momentum: float = 0.9,
|
||||
betas: Tuple[float, ...] = (0.9, 0.999),
|
||||
milestones: tuple = (),
|
||||
max_epochs: int = 1000,
|
||||
):
|
||||
"""
|
||||
Initialize the optimizer (optionally from checkpoint state)
|
||||
and the learning rate scheduler.
|
||||
|
||||
Args:
|
||||
model: The model with optionally loaded weights
|
||||
optimizer_state: The state dict for the optimizer. If None
|
||||
it has not been loaded from checkpoint
|
||||
last_epoch: If the model was loaded from checkpoint this will be the
|
||||
number of the last epoch that was saved
|
||||
breed: The type of optimizer to use e.g. adam
|
||||
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
|
||||
lr_policy: The policy to use for learning rate. Currently, only "multistep:
|
||||
is supported.
|
||||
lr: The value for the initial learning rate
|
||||
gamma: Multiplicative factor of learning rate decay
|
||||
momentum: Momentum factor for SGD optimizer
|
||||
betas: Coefficients used for computing running averages of gradient and its square
|
||||
in the Adam optimizer
|
||||
milestones: List of increasing epoch indices at which the learning rate is
|
||||
modified
|
||||
max_epochs: The maximum number of epochs to run the optimizer for
|
||||
|
||||
Returns:
|
||||
optimizer: Optimizer module, optionally loaded from checkpoint
|
||||
scheduler: Learning rate scheduler module
|
||||
|
||||
Raise:
|
||||
ValueError if `breed` or `lr_policy` are not supported.
|
||||
"""
|
||||
|
||||
# Get the parameters to optimize
|
||||
if hasattr(model, "_get_param_groups"): # use the model function
|
||||
# pyre-ignore[29]
|
||||
p_groups = model._get_param_groups(lr, wd=weight_decay)
|
||||
else:
|
||||
allprm = [prm for prm in model.parameters() if prm.requires_grad]
|
||||
p_groups = [{"params": allprm, "lr": lr}]
|
||||
|
||||
# Intialize the optimizer
|
||||
if breed == "sgd":
|
||||
optimizer = torch.optim.SGD(
|
||||
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
|
||||
)
|
||||
elif breed == "adagrad":
|
||||
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
|
||||
elif breed == "adam":
|
||||
optimizer = torch.optim.Adam(
|
||||
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such solver type %s" % breed)
|
||||
logger.info(" -> solver type = %s" % breed)
|
||||
|
||||
# Load state from checkpoint
|
||||
if optimizer_state is not None:
|
||||
logger.info(" -> setting loaded optimizer state")
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
# Initialize the learning rate scheduler
|
||||
if lr_policy == "multistep":
|
||||
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
||||
optimizer,
|
||||
milestones=milestones,
|
||||
gamma=gamma,
|
||||
)
|
||||
else:
|
||||
raise ValueError("no such lr policy %s" % lr_policy)
|
||||
|
||||
# When loading from checkpoint, this will make sure that the
|
||||
# lr is correctly set even after returning
|
||||
for _ in range(last_epoch):
|
||||
scheduler.step()
|
||||
|
||||
optimizer.zero_grad()
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
enable_get_default_args(init_optimizer)
|
@ -8,11 +8,12 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import experiment
|
||||
import torch
|
||||
from hydra import compose, initialize_config_dir
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from .. import experiment
|
||||
|
||||
|
||||
def interactive_testing_requested() -> bool:
|
||||
"""
|
||||
|
@ -21,7 +21,6 @@ from typing import Optional, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as Fu
|
||||
from experiment import init_model
|
||||
from omegaconf import OmegaConf
|
||||
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
|
||||
from pytorch3d.implicitron.dataset.dataset_base import DatasetBase, FrameData
|
||||
@ -38,6 +37,8 @@ from pytorch3d.implicitron.tools.vis_utils import (
|
||||
)
|
||||
from tqdm import tqdm
|
||||
|
||||
from .experiment import init_model
|
||||
|
||||
|
||||
def render_sequence(
|
||||
dataset: DatasetBase,
|
||||
|
@ -9,6 +9,7 @@ import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@ -99,14 +100,14 @@ def save_model(model, stats, fl, optimizer=None, cfg=None):
|
||||
return flstats, flmodel, flopt
|
||||
|
||||
|
||||
def load_model(fl):
|
||||
def load_model(fl, map_location: Optional[dict]):
|
||||
flstats = get_stats_path(fl)
|
||||
flmodel = get_model_path(fl)
|
||||
flopt = get_optimizer_path(fl)
|
||||
model_state_dict = torch.load(flmodel)
|
||||
model_state_dict = torch.load(flmodel, map_location=map_location)
|
||||
stats = load_stats(flstats)
|
||||
if os.path.isfile(flopt):
|
||||
optimizer = torch.load(flopt)
|
||||
optimizer = torch.load(flopt, map_location=map_location)
|
||||
else:
|
||||
optimizer = None
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user