Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-12-20 06:10:34 +08:00)
Replace pluggable components to create a proper Configurable hierarchy.
Summary:
This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows:
```
Experiment
data_source: ImplicitronDataSource
dataset_map_provider
data_loader_map_provider
model_factory: ImplicitronModelFactory
model: GenericModel
optimizer_factory: ImplicitronOptimizerFactory
training_loop: ImplicitronTrainingLoop
evaluator: ImplicitronEvaluator
```
1) Experiment (formerly ExperimentConfig) is now a top-level Configurable whose members are mainly (mostly new) high-level factory Configurables.
2) Experiment's job is to run the factories, perform the accelerate setup, and then pass the results to the main training loop.
3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects.
4) TrainingLoop is a new configurable that runs the main training loop and the inner train-validate step.
5) Evaluator is a new configurable that TrainingLoop uses to run validation/test steps.
6) GenericModel is no longer the only model choice. Instead, ImplicitronModelBase (instantiated with GenericModel by default) is a member of Experiment and can easily be replaced with a user-provided implementation.
All the new Configurables are children of ReplaceableBase and can easily be replaced with custom implementations (see the sketch after this summary).
In addition, I added support for an exponential LR schedule, updated the config files and the tests, and added a config file that reproduces NeRF results together with a test that runs the repro experiment.
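For orientation, here is a minimal sketch (not part of this diff) of what "easily replaced with custom implementations" looks like with the registry machinery used in the files below. The class MyModelFactory and the import path of model_factory.py are hypothetical; the `*_class_type` selection key follows the convention visible in the diff (e.g. model_class_type, evaluator_class_type) and is assumed, not confirmed, for Experiment itself.

```python
# Hedged sketch: registering a custom replacement for a pluggable component.
# The import path below is an assumption about how the trainer package is laid out.
from pytorch3d.implicitron.tools.config import registry

from projects.implicitron_trainer.impl.model_factory import ModelFactoryBase


@registry.register
class MyModelFactory(ModelFactoryBase):
    """A hypothetical drop-in replacement for ImplicitronModelFactory."""

    def __call__(self, exp_dir: str, **kwargs):
        # Build and return any ImplicitronModelBase subclass here.
        raise NotImplementedError()


# Selecting it would then amount to setting something like
#   model_factory_class_type: MyModelFactory
# in the experiment config (key name assumed from the *_class_type convention).
```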
Reviewed By: bottler
Differential Revision: D37723227
fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27
Committed by Facebook GitHub Bot. Parent: 6b481595f0. Commit: 1b0584f7bd.
Deleted file (49 lines, @@ -1,49 +0,0 @@), the old flat ExperimentConfig:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import field
from typing import Any, Dict, Tuple

from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field

from .optimization import init_optimizer


class ExperimentConfig(Configurable):
    generic_model_args: DictConfig = get_default_args_field(GenericModel)
    solver_args: DictConfig = get_default_args_field(init_optimizer)
    data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
    architecture: str = "generic"
    detect_anomaly: bool = False
    eval_only: bool = False
    exp_dir: str = "./data/default_experiment/"
    exp_idx: int = 0
    gpu_idx: int = 0
    metric_print_interval: int = 5
    resume: bool = True
    resume_epoch: int = -1
    seed: int = 0
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1
    visdom_env: str = ""
    visdom_port: int = 8097
    visdom_server: str = "http://127.0.0.1"
    visualize_interval: int = 1000
    clip_grad: float = 0.0
    camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98

    hydra: Dict[str, Any] = field(
        default_factory=lambda: {
            "run": {"dir": "."},  # Make hydra not change the working dir.
            "output_subdir": None,  # disable storing the .hydra logs
        }
    )
```
projects/implicitron_trainer/impl/model_factory.py (new file, 199 lines, @@ -0,0 +1,199 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import List, Optional

import torch.optim

from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats

logger = logging.getLogger(__name__)


class ModelFactoryBase(ReplaceableBase):
    def __call__(self, **kwargs) -> ImplicitronModelBase:
        """
        Initialize the model (possibly from a previously saved state).

        Returns: An instance of ImplicitronModelBase.
        """
        raise NotImplementedError()

    def load_stats(self, **kwargs) -> Stats:
        """
        Initialize or load a Stats object.
        """
        raise NotImplementedError()


@registry.register
class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
    """
    A factory class that initializes an implicit rendering model.

    Members:
        force_load: If True, throw a FileNotFoundError if `resume` is True but
            a model checkpoint cannot be found.
        model: An ImplicitronModelBase object.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a model with
            initial weights unless `force_load` is True.
        resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
    """

    force_load: bool = False
    model: ImplicitronModelBase
    model_class_type: str = "GenericModel"
    resume: bool = False
    resume_epoch: int = -1
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        exp_dir: str,
        accelerator: Optional[Accelerator] = None,
    ) -> ImplicitronModelBase:
        """
        Returns an instance of `ImplicitronModelBase`, possibly loaded from a
        checkpoint (if self.resume, self.resume_epoch specify so).

        Args:
            exp_dir: Root experiment directory.
            accelerator: An Accelerator object.

        Returns:
            model: The model with optionally loaded weights from checkpoint.

        Raises:
            FileNotFoundError if `force_load` is True but checkpoint not found.
        """
        # Determine the network outputs that should be logged
        if hasattr(self.model, "log_vars"):
            log_vars = list(self.model.log_vars)  # pyre-ignore [6]
        else:
            log_vars = ["objective"]

        # Retrieve the last checkpoint
        if self.resume_epoch > 0:
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        else:
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            logger.info("found previous model %s" % model_path)
            if self.force_load or self.resume:
                logger.info(" -> resuming")

                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                model_state_dict = torch.load(
                    model_io.get_model_path(model_path), map_location=map_location
                )

                try:
                    self.model.load_state_dict(model_state_dict, strict=True)
                except RuntimeError as e:
                    logger.error(e)
                    logger.info(
                        "Can't load state dict in strict mode! -> trying non-strict"
                    )
                    self.model.load_state_dict(model_state_dict, strict=False)
                self.model.log_vars = log_vars  # pyre-ignore [16]
            else:
                logger.info(" -> but not resuming -> starting from scratch")
        elif self.force_load:
            raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

        return self.model

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        clear_stats: bool = False,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            clear_stats: If True, do not load stats from the checkpoint specified
                by self.resume and self.resume_epoch; instead, create a fresh
                stats object.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).
        """
        # Init the stats struct
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            visdom_env=visdom_env_charts,
            verbose=False,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )
        if self.resume_epoch > 0:
            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        else:
            model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine if stats should be reset
            if not clear_stats:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    stats = stats_load

                # Update stats properties in case stats was reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info(" -> clearing stats")

        return stats
```
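A minimal usage sketch of the factory above, for orientation only. It assumes the Configurable has been expanded with expand_args_fields (which exists in pytorch3d.implicitron.tools.config), that the module is importable under the path shown, and that "my_experiment_dir" is a hypothetical directory; none of this is part of the diff.

```python
# Hedged sketch: building a model and its stats via ImplicitronModelFactory.
from pytorch3d.implicitron.tools.config import expand_args_fields

# Import path is an assumption about the trainer package layout.
from projects.implicitron_trainer.impl.model_factory import ImplicitronModelFactory

expand_args_fields(ImplicitronModelFactory)
factory = ImplicitronModelFactory(resume=True)

# Build the model, loading the latest checkpoint from the experiment dir if one exists.
model = factory(exp_dir="my_experiment_dir")
# Load (or create) the Stats object matching the model's logged variables.
stats = factory.load_stats(log_vars=list(model.log_vars), exp_dir="my_experiment_dir")
```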
Deleted file (109 lines, @@ -1,109 +0,0 @@), the old init_optimizer helper:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Any, Dict, Optional, Tuple

import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args

logger = logging.getLogger(__name__)


def init_optimizer(
    model: GenericModel,
    optimizer_state: Optional[Dict[str, Any]],
    last_epoch: int,
    breed: str = "adam",
    weight_decay: float = 0.0,
    lr_policy: str = "multistep",
    lr: float = 0.0005,
    gamma: float = 0.1,
    momentum: float = 0.9,
    betas: Tuple[float, ...] = (0.9, 0.999),
    milestones: Tuple[int, ...] = (),
    max_epochs: int = 1000,
):
    """
    Initialize the optimizer (optionally from checkpoint state)
    and the learning rate scheduler.

    Args:
        model: The model with optionally loaded weights
        optimizer_state: The state dict for the optimizer. If None
            it has not been loaded from checkpoint
        last_epoch: If the model was loaded from checkpoint this will be the
            number of the last epoch that was saved
        breed: The type of optimizer to use e.g. adam
        weight_decay: The optimizer weight_decay (L2 penalty on model weights)
        lr_policy: The policy to use for learning rate. Currently, only
            "multistep" is supported.
        lr: The value for the initial learning rate
        gamma: Multiplicative factor of learning rate decay
        momentum: Momentum factor for SGD optimizer
        betas: Coefficients used for computing running averages of gradient and
            its square in the Adam optimizer
        milestones: List of increasing epoch indices at which the learning rate
            is modified
        max_epochs: The maximum number of epochs to run the optimizer for

    Returns:
        optimizer: Optimizer module, optionally loaded from checkpoint
        scheduler: Learning rate scheduler module

    Raises:
        ValueError if `breed` or `lr_policy` are not supported.
    """

    # Get the parameters to optimize
    if hasattr(model, "_get_param_groups"):  # use the model function
        # pyre-ignore[29]
        p_groups = model._get_param_groups(lr, wd=weight_decay)
    else:
        allprm = [prm for prm in model.parameters() if prm.requires_grad]
        p_groups = [{"params": allprm, "lr": lr}]

    # Initialize the optimizer
    if breed == "sgd":
        optimizer = torch.optim.SGD(
            p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    elif breed == "adagrad":
        optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
    elif breed == "adam":
        optimizer = torch.optim.Adam(
            p_groups, lr=lr, betas=betas, weight_decay=weight_decay
        )
    else:
        raise ValueError("no such solver type %s" % breed)
    logger.info(" -> solver type = %s" % breed)

    # Load state from checkpoint
    if optimizer_state is not None:
        logger.info(" -> setting loaded optimizer state")
        optimizer.load_state_dict(optimizer_state)

    # Initialize the learning rate scheduler
    if lr_policy == "multistep":
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=gamma,
        )
    else:
        raise ValueError("no such lr policy %s" % lr_policy)

    # When loading from checkpoint, this will make sure that the
    # lr is correctly set even after returning
    for _ in range(last_epoch):
        scheduler.step()

    optimizer.zero_grad()
    return optimizer, scheduler


enable_get_default_args(init_optimizer)
```
projects/implicitron_trainer/impl/optimizer_factory.py (new file, 197 lines, @@ -0,0 +1,197 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Any, Dict, Optional, Tuple

import torch.optim

from accelerate import Accelerator

from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)

logger = logging.getLogger(__name__)


class OptimizerFactoryBase(ReplaceableBase):
    def __call__(
        self, model: ImplicitronModelBase, **kwargs
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer and lr scheduler.

        Args:
            model: The model with optionally loaded weights.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).
        """
        raise NotImplementedError()


@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
    """
    A factory that initializes the optimizer and lr scheduler.

    Members:
        betas: Beta parameters for the Adam optimizer.
        breed: The type of optimizer to use. We currently support SGD, Adagrad
            and Adam.
        exponential_lr_step_size: With Exponential policy only,
            lr = lr * gamma ** (epoch/step_size)
        gamma: Multiplicative factor of learning rate decay.
        lr: The value for the initial learning rate.
        lr_policy: The policy to use for learning rate. We currently support
            MultiStepLR and Exponential policies.
        momentum: Momentum factor for the SGD optimizer (SGD only).
        multistep_lr_milestones: With MultiStepLR policy only: list of
            increasing epoch indices at which the learning rate is modified.
        resume: If True, attempt to load the last checkpoint from `exp_dir`
            passed to __call__. Failure to do so will return a newly initialized
            optimizer.
        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
            `resume_epoch` <= 0, then resume from the latest checkpoint.
        weight_decay: The optimizer weight_decay (L2 penalty on model weights).
    """

    betas: Tuple[float, ...] = (0.9, 0.999)
    breed: str = "Adam"
    exponential_lr_step_size: int = 250
    gamma: float = 0.1
    lr: float = 0.0005
    lr_policy: str = "MultiStepLR"
    momentum: float = 0.9
    multistep_lr_milestones: tuple = ()
    resume: bool = False
    resume_epoch: int = -1
    weight_decay: float = 0.0

    def __post_init__(self):
        run_auto_creation(self)

    def __call__(
        self,
        last_epoch: int,
        model: ImplicitronModelBase,
        accelerator: Optional[Accelerator] = None,
        exp_dir: Optional[str] = None,
        **kwargs,
    ) -> Tuple[torch.optim.Optimizer, Any]:
        """
        Initialize the optimizer (optionally from a checkpoint) and the lr scheduler.

        Args:
            last_epoch: If the model was loaded from checkpoint this will be the
                number of the last epoch that was saved.
            model: The model with optionally loaded weights.
            accelerator: An optional Accelerator instance.
            exp_dir: Root experiment directory.

        Returns:
            An optimizer module (optionally loaded from a checkpoint) and
            a learning rate scheduler module (should be a subclass of torch.optim's
            lr_scheduler._LRScheduler).
        """
        # Get the parameters to optimize
        if hasattr(model, "_get_param_groups"):  # use the model function
            # pyre-ignore[29]
            p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
        else:
            allprm = [prm for prm in model.parameters() if prm.requires_grad]
            p_groups = [{"params": allprm, "lr": self.lr}]

        # Initialize the optimizer
        if self.breed == "SGD":
            optimizer = torch.optim.SGD(
                p_groups,
                lr=self.lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
            )
        elif self.breed == "Adagrad":
            optimizer = torch.optim.Adagrad(
                p_groups, lr=self.lr, weight_decay=self.weight_decay
            )
        elif self.breed == "Adam":
            optimizer = torch.optim.Adam(
                p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
            )
        else:
            raise ValueError("no such solver type %s" % self.breed)
        logger.info(" -> solver type = %s" % self.breed)

        # Load state from checkpoint
        optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
        if optimizer_state is not None:
            logger.info(" -> setting loaded optimizer state")
            optimizer.load_state_dict(optimizer_state)

        # Initialize the learning rate scheduler
        if self.lr_policy.casefold() == "MultiStepLR".casefold():
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=self.multistep_lr_milestones,
                gamma=self.gamma,
            )
        elif self.lr_policy.casefold() == "Exponential".casefold():
            scheduler = torch.optim.lr_scheduler.LambdaLR(
                optimizer,
                lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
                verbose=False,
            )
        else:
            raise ValueError("no such lr policy %s" % self.lr_policy)

        # When loading from checkpoint, this will make sure that the
        # lr is correctly set even after returning.
        for _ in range(last_epoch):
            scheduler.step()

        optimizer.zero_grad()

        return optimizer, scheduler

    def _get_optimizer_state(
        self,
        exp_dir: Optional[str],
        accelerator: Optional[Accelerator] = None,
    ) -> Optional[Dict[str, Any]]:
        """
        Load an optimizer state from a checkpoint.
        """
        if exp_dir is None or not self.resume:
            return None
        if self.resume_epoch > 0:
            save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
        else:
            save_path = model_io.find_last_checkpoint(exp_dir)
        optimizer_state = None
        if save_path is not None:
            logger.info(f"Found previous optimizer state {save_path}.")
            logger.info(" -> resuming")
            opt_path = model_io.get_optimizer_path(save_path)

            if os.path.isfile(opt_path):
                map_location = None
                if accelerator is not None and not accelerator.is_local_main_process:
                    map_location = {
                        "cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
                    }
                optimizer_state = torch.load(opt_path, map_location)
            else:
                optimizer_state = None
        return optimizer_state
```
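The new Exponential policy above multiplies the base learning rate by gamma ** (epoch / exponential_lr_step_size) via LambdaLR. The snippet below is only an illustration of that multiplier with the defaults from the diff (gamma=0.1, exponential_lr_step_size=250); it is not part of the commit.

```python
# Illustrative only: the decay multiplier applied to the base lr by the
# Exponential policy, evaluated at a few epochs with the default settings.
gamma, step_size = 0.1, 250
for epoch in (0, 125, 250, 500):
    print(epoch, gamma ** (epoch / step_size))
# 0 -> 1.0, 125 -> ~0.316, 250 -> 0.1, 500 -> 0.01 (times the base lr)
```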
projects/implicitron_trainer/impl/training_loop.py (new file, 365 lines, @@ -0,0 +1,365 @@):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
import time
from typing import Any, Optional

import numpy as np
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader

logger = logging.getLogger(__name__)


class TrainingLoopBase(ReplaceableBase):
    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        raise NotImplementedError()


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
    """
    Members:
        eval_only: If True, only run evaluation using the test dataloader.
        evaluator: An EvaluatorBase instance, used to evaluate training results.
        max_epochs: Train for this many epochs. Note that if the model was
            loaded from a checkpoint, we will restart training at the appropriate
            epoch and run for (max_epochs - checkpoint_epoch) epochs.
        seed: A random seed to ensure reproducibility.
        store_checkpoints: If True, store model and optimizer state checkpoints.
        store_checkpoints_purge: If >= 0, remove any checkpoints older than or
            equal to this many epochs.
        test_interval: Evaluate on a test dataloader each `test_interval` epochs.
        test_when_finished: If True, evaluate on a test dataloader when training
            completes.
        validation_interval: Validate each `validation_interval` epochs.
        clip_grad: Optionally clip the gradient norms; if set to a value <= 0.0,
            no clipping is performed.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted.
    """

    # Parameters of the outer training loop.
    eval_only: bool = False
    evaluator: EvaluatorBase
    evaluator_class_type: str = "ImplicitronEvaluator"
    max_epochs: int = 1000
    seed: int = 0
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1

    # Parameters of a single training-validation step.
    clip_grad: float = 0.0
    metric_print_interval: int = 5
    visualize_interval: int = 1000

    def __post_init__(self):
        run_auto_creation(self)

    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        all_train_cameras: Optional[CamerasBase],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        task: Task,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        based on the specified config file.
        """
        _seed_all_random_engines(self.seed)
        start_epoch = stats.epoch + 1
        assert scheduler.last_epoch == stats.epoch + 1
        assert scheduler.last_epoch == start_epoch

        # only run evaluation on the test dataloader
        if self.eval_only:
            if test_loader is not None:
                self.evaluator.run(
                    all_train_cameras=all_train_cameras,
                    dataloader=test_loader,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    model=model,
                    task=task,
                )
                return
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

        # loop through epochs
        for epoch in range(start_epoch, self.max_epochs):
            # automatic new_epoch and plotting of stats at every epoch start
            with stats:

                # Make sure to re-seed random generators to ensure reproducibility
                # even after restart.
                _seed_all_random_engines(self.seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")

                # train loop
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )

                # val loop (optional)
                if val_loader is not None and epoch % self.validation_interval == 0:
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )

                # eval loop (optional)
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self.evaluator.run(
                        all_train_cameras=all_train_cameras,
                        device=device,
                        dataloader=test_loader,
                        model=model,
                        task=task,
                    )

                assert stats.epoch == epoch, "inconsistent stats!"
                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        if self.test_when_finished:
            if test_loader is not None:
                self.evaluator.run(
                    all_train_cameras=all_train_cameras,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    dataloader=test_loader,
                    model=model,
                    task=task,
                )
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        This is the main loop for training and evaluation including:
        model forward pass, loss computation, backward pass and visualization.

        Args:
            epoch: The index of the current epoch
            loader: The dataloader to use for the loop
            model: The model module optionally loaded from checkpoint
            optimizer: The optimizer module optionally loaded from checkpoint
            stats: The stats struct, also optionally loaded from checkpoint
            validation: If true, run the loop with the model in eval mode
                and skip the backward pass
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """

        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        # get the visdom env name
        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )

        # Iterate through the batches
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1

            # move to gpu where possible (in place)
            net_input = net_input.to(device)

            # run the forward pass
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )

            # make sure we don't overwrite something
            assert all(k not in preds for k in net_input.keys())
            # merge everything into one big dict
            preds.update(net_input)

            # update the stats logger
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            # pyre-ignore [16]
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # print textual status update
            if it % self.metric_print_interval == 0 or last_iter:
                stats.print(stat_set=trainmode, max_it=n_batches)

            # visualize results
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    # pyre-ignore [29]
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )

            # optimizer step
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"
                # backprop
                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)
                if self.clip_grad > 0.0:
                    # Optionally clip the gradient norms.
                    total_norm = torch.nn.utils.clip_grad_norm(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            f"Clipping gradient: {total_norm}"
                            + f" with coef {self.clip_grad / float(total_norm)}."
                        )

                optimizer.step()

    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save a model and its corresponding Stats object to a file, if
        `self.store_checkpoints` is True. In addition, if
        `self.store_checkpoints_purge` is True, remove any checkpoints older
        than `self.store_checkpoints_purge` epochs old.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )


def _seed_all_random_engines(seed: int) -> None:
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
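To tie the pieces together, here is a rough sketch of how the factories and the training loop above could be wired, approximating what the summary says Experiment does (run the factories, then hand the results to the training loop). All names and the surrounding plumbing are assumptions; the real Experiment class additionally handles accelerate setup and data-source construction, which are omitted here.

```python
# Hedged sketch, not the actual Experiment implementation.
from typing import Optional

import torch
from accelerate import Accelerator


def run_experiment_sketch(
    data,  # hypothetical holder of train/val/test DataLoaders, task, cameras
    model_factory,  # e.g. an ImplicitronModelFactory instance
    optimizer_factory,  # e.g. an ImplicitronOptimizerFactory instance
    training_loop,  # e.g. an ImplicitronTrainingLoop instance
    exp_dir: str,
    accelerator: Optional[Accelerator] = None,
) -> None:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # 1) Build (and possibly resume) the model and its stats.
    model = model_factory(exp_dir=exp_dir, accelerator=accelerator)
    model.to(device)
    log_vars = list(getattr(model, "log_vars", ["objective"]))
    stats = model_factory.load_stats(log_vars=log_vars, exp_dir=exp_dir)

    # 2) Build (and possibly resume) the optimizer and lr scheduler.
    # The training loop asserts scheduler.last_epoch == stats.epoch + 1,
    # hence last_epoch=stats.epoch + 1 here.
    optimizer, scheduler = optimizer_factory(
        last_epoch=stats.epoch + 1,
        model=model,
        accelerator=accelerator,
        exp_dir=exp_dir,
    )

    # 3) Hand everything to the training loop.
    training_loop.run(
        train_loader=data.train_loader,
        val_loader=data.val_loader,
        test_loader=data.test_loader,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        accelerator=accelerator,
        all_train_cameras=data.all_train_cameras,
        device=device,
        exp_dir=exp_dir,
        stats=stats,
        task=data.task,
    )
```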