Replace pluggable components to create a proper Configurable hierarchy.

Summary:
This large diff rewrites a significant portion of Implicitron's config hierarchy. The new hierarchy, and some of the default implementation classes, are as follows:
```
Experiment
    data_source: ImplicitronDataSource
        dataset_map_provider
        data_loader_map_provider
    model_factory: ImplicitronModelFactory
        model: GenericModel
    optimizer_factory: ImplicitronOptimizerFactory
    training_loop: ImplicitronTrainingLoop
        evaluator: ImplicitronEvaluator
```

1) Experiment (used to be ExperimentConfig) is now a top-level Configurable and contains as members mainly (mostly new) high-level factory Configurables.
2) Experiment's job is to run factories, do some accelerate setup and then pass the results to the main training loop.
3) ImplicitronOptimizerFactory and ImplicitronModelFactory are new high-level factories that create the optimizer, scheduler, model, and stats objects.
4) TrainingLoop is a new configurable that runs the main training loop and the inner train-validate step.
5) Evaluator is a new configurable that TrainingLoop uses to run validation/test steps.
6) GenericModel is not the only model choice anymore. Instead, ImplicitronModelBase (by default instantiated with GenericModel) is a member of Experiment and can be easily replaced by a custom implementation by the user.

All the new Configurables are children of ReplaceableBase, and can be easily replaced with custom implementations.

In addition, I added support for the exponential LR schedule, updated the config files and the test, as well as added a config file that reproduces NERF results and a test to run the repro experiment.

Reviewed By: bottler

Differential Revision: D37723227

fbshipit-source-id: b36bee880d6aa53efdd2abfaae4489d8ab1e8a27
This commit is contained in:
Krzysztof Chalupka
2022-07-29 17:32:51 -07:00
committed by Facebook GitHub Bot
parent 6b481595f0
commit 1b0584f7bd
42 changed files with 2045 additions and 1478 deletions

View File

@@ -1,49 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import field
from typing import Any, Dict, Tuple
from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.data_source import ImplicitronDataSource
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import Configurable, get_default_args_field
from .optimization import init_optimizer
class ExperimentConfig(Configurable):
generic_model_args: DictConfig = get_default_args_field(GenericModel)
solver_args: DictConfig = get_default_args_field(init_optimizer)
data_source_args: DictConfig = get_default_args_field(ImplicitronDataSource)
architecture: str = "generic"
detect_anomaly: bool = False
eval_only: bool = False
exp_dir: str = "./data/default_experiment/"
exp_idx: int = 0
gpu_idx: int = 0
metric_print_interval: int = 5
resume: bool = True
resume_epoch: int = -1
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
visdom_env: str = ""
visdom_port: int = 8097
visdom_server: str = "http://127.0.0.1"
visualize_interval: int = 1000
clip_grad: float = 0.0
camera_difficulty_bin_breaks: Tuple[float, ...] = 0.97, 0.98
hydra: Dict[str, Any] = field(
default_factory=lambda: {
"run": {"dir": "."}, # Make hydra not change the working dir.
"output_subdir": None, # disable storing the .hydra logs
}
)

View File

@@ -0,0 +1,199 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import List, Optional
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
logger = logging.getLogger(__name__)
class ModelFactoryBase(ReplaceableBase):
def __call__(self, **kwargs) -> ImplicitronModelBase:
"""
Initialize the model (possibly from a previously saved state).
Returns: An instance of ImplicitronModelBase.
"""
raise NotImplementedError()
def load_stats(self, **kwargs) -> Stats:
"""
Initialize or load a Stats object.
"""
raise NotImplementedError()
@registry.register
class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
"""
A factory class that initializes an implicit rendering model.
Members:
force_load: If True, throw a FileNotFoundError if `resume` is True but
a model checkpoint cannot be found.
model: An ImplicitronModelBase object.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a model with ini-
tial weights unless `force_load` is True.
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
`resume_epoch` <= 0, then resume from the latest checkpoint.
visdom_env: The name of the Visdom environment to use for plotting.
visdom_port: The Visdom port.
visdom_server: Address of the Visdom server.
"""
force_load: bool = False
model: ImplicitronModelBase
model_class_type: str = "GenericModel"
resume: bool = False
resume_epoch: int = -1
visdom_env: str = ""
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
visdom_server: str = "http://127.0.0.1"
def __post_init__(self):
run_auto_creation(self)
def __call__(
self,
exp_dir: str,
accelerator: Optional[Accelerator] = None,
) -> ImplicitronModelBase:
"""
Returns an instance of `ImplicitronModelBase`, possibly loaded from a
checkpoint (if self.resume, self.resume_epoch specify so).
Args:
exp_dir: Root experiment directory.
accelerator: An Accelerator object.
Returns:
model: The model with optionally loaded weights from checkpoint
Raise:
FileNotFoundError if `force_load` is True but checkpoint not found.
"""
# Determine the network outputs that should be logged
if hasattr(self.model, "log_vars"):
log_vars = list(self.model.log_vars) # pyre-ignore [6]
else:
log_vars = ["objective"]
# Retrieve the last checkpoint
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
logger.info("found previous model %s" % model_path)
if self.force_load or self.resume:
logger.info(" -> resuming")
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
model_state_dict = torch.load(
model_io.get_model_path(model_path), map_location=map_location
)
try:
self.model.load_state_dict(model_state_dict, strict=True)
except RuntimeError as e:
logger.error(e)
logger.info(
"Cant load state dict in strict mode! -> trying non-strict"
)
self.model.load_state_dict(model_state_dict, strict=False)
self.model.log_vars = log_vars # pyre-ignore [16]
else:
logger.info(" -> but not resuming -> starting from scratch")
elif self.force_load:
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
return self.model
def load_stats(
self,
log_vars: List[str],
exp_dir: str,
clear_stats: bool = False,
**kwargs,
) -> Stats:
"""
Load Stats that correspond to the model's log_vars.
Args:
log_vars: A list of variable names to log. Should be a subset of the
`preds` returned by the forward function of the corresponding
ImplicitronModelBase instance.
exp_dir: Root experiment directory.
clear_stats: If True, do not load stats from the checkpoint speci-
fied by self.resume and self.resume_epoch; instead, create a
fresh stats object.
stats: The stats structure (optionally loaded from checkpoint)
"""
# Init the stats struct
visdom_env_charts = (
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
)
stats = Stats(
# log_vars should be a list, but OmegaConf might load them as ListConfig
list(log_vars),
visdom_env=visdom_env_charts,
verbose=False,
visdom_server=self.visdom_server,
visdom_port=self.visdom_port,
)
if self.resume_epoch > 0:
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
model_path = model_io.find_last_checkpoint(exp_dir)
if model_path is not None:
stats_path = model_io.get_stats_path(model_path)
stats_load = model_io.load_stats(stats_path)
# Determine if stats should be reset
if not clear_stats:
if stats_load is None:
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
last_epoch = model_io.parse_epoch_from_model_path(model_path)
logger.info(f"Estimated resume epoch = {last_epoch}")
# Reset the stats struct
for _ in range(last_epoch + 1):
stats.new_epoch()
assert last_epoch == stats.epoch
else:
stats = stats_load
# Update stats properties incase it was reset on load
stats.visdom_env = visdom_env_charts
stats.visdom_server = self.visdom_server
stats.visdom_port = self.visdom_port
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
stats.synchronize_logged_vars(log_vars)
else:
logger.info(" -> clearing stats")
return stats

View File

@@ -1,109 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import Any, Dict, Optional, Tuple
import torch
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.tools.config import enable_get_default_args
logger = logging.getLogger(__name__)
def init_optimizer(
model: GenericModel,
optimizer_state: Optional[Dict[str, Any]],
last_epoch: int,
breed: str = "adam",
weight_decay: float = 0.0,
lr_policy: str = "multistep",
lr: float = 0.0005,
gamma: float = 0.1,
momentum: float = 0.9,
betas: Tuple[float, ...] = (0.9, 0.999),
milestones: Tuple[int, ...] = (),
max_epochs: int = 1000,
):
"""
Initialize the optimizer (optionally from checkpoint state)
and the learning rate scheduler.
Args:
model: The model with optionally loaded weights
optimizer_state: The state dict for the optimizer. If None
it has not been loaded from checkpoint
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved
breed: The type of optimizer to use e.g. adam
weight_decay: The optimizer weight_decay (L2 penalty on model weights)
lr_policy: The policy to use for learning rate. Currently, only "multistep:
is supported.
lr: The value for the initial learning rate
gamma: Multiplicative factor of learning rate decay
momentum: Momentum factor for SGD optimizer
betas: Coefficients used for computing running averages of gradient and its square
in the Adam optimizer
milestones: List of increasing epoch indices at which the learning rate is
modified
max_epochs: The maximum number of epochs to run the optimizer for
Returns:
optimizer: Optimizer module, optionally loaded from checkpoint
scheduler: Learning rate scheduler module
Raise:
ValueError if `breed` or `lr_policy` are not supported.
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(lr, wd=weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": lr}]
# Intialize the optimizer
if breed == "sgd":
optimizer = torch.optim.SGD(
p_groups, lr=lr, momentum=momentum, weight_decay=weight_decay
)
elif breed == "adagrad":
optimizer = torch.optim.Adagrad(p_groups, lr=lr, weight_decay=weight_decay)
elif breed == "adam":
optimizer = torch.optim.Adam(
p_groups, lr=lr, betas=betas, weight_decay=weight_decay
)
else:
raise ValueError("no such solver type %s" % breed)
logger.info(" -> solver type = %s" % breed)
# Load state from checkpoint
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if lr_policy == "multistep":
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=gamma,
)
else:
raise ValueError("no such lr policy %s" % lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
enable_get_default_args(init_optimizer)

View File

@@ -0,0 +1,197 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import Any, Dict, Optional, Tuple
import torch.optim
from accelerate import Accelerator
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.tools import model_io
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
logger = logging.getLogger(__name__)
class OptimizerFactoryBase(ReplaceableBase):
def __call__(
self, model: ImplicitronModelBase, **kwargs
) -> Tuple[torch.optim.Optimizer, Any]:
"""
Initialize the optimizer and lr scheduler.
Args:
model: The model with optionally loaded weights.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
lr_scheduler._LRScheduler).
"""
raise NotImplementedError()
@registry.register
class ImplicitronOptimizerFactory(OptimizerFactoryBase):
"""
A factory that initializes the optimizer and lr scheduler.
Members:
betas: Beta parameters for the Adam optimizer.
breed: The type of optimizer to use. We currently support SGD, Adagrad
and Adam.
exponential_lr_step_size: With Exponential policy only,
lr = lr * gamma ** (epoch/step_size)
gamma: Multiplicative factor of learning rate decay.
lr: The value for the initial learning rate.
lr_policy: The policy to use for learning rate. We currently support
MultiStepLR and Exponential policies.
momentum: A momentum value (for SGD only).
multistep_lr_milestones: With MultiStepLR policy only: list of
increasing epoch indices at which the learning rate is modified.
momentum: Momentum factor for SGD optimizer.
resume: If True, attempt to load the last checkpoint from `exp_dir`
passed to __call__. Failure to do so will return a newly initialized
optimizer.
resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
`resume_epoch` <= 0, then resume from the latest checkpoint.
weight_decay: The optimizer weight_decay (L2 penalty on model weights).
"""
betas: Tuple[float, ...] = (0.9, 0.999)
breed: str = "Adam"
exponential_lr_step_size: int = 250
gamma: float = 0.1
lr: float = 0.0005
lr_policy: str = "MultiStepLR"
momentum: float = 0.9
multistep_lr_milestones: tuple = ()
resume: bool = False
resume_epoch: int = -1
weight_decay: float = 0.0
def __post_init__(self):
run_auto_creation(self)
def __call__(
self,
last_epoch: int,
model: ImplicitronModelBase,
accelerator: Optional[Accelerator] = None,
exp_dir: Optional[str] = None,
**kwargs,
) -> Tuple[torch.optim.Optimizer, Any]:
"""
Initialize the optimizer (optionally from a checkpoint) and the lr scheduluer.
Args:
last_epoch: If the model was loaded from checkpoint this will be the
number of the last epoch that was saved.
model: The model with optionally loaded weights.
accelerator: An optional Accelerator instance.
exp_dir: Root experiment directory.
Returns:
An optimizer module (optionally loaded from a checkpoint) and
a learning rate scheduler module (should be a subclass of torch.optim's
lr_scheduler._LRScheduler).
"""
# Get the parameters to optimize
if hasattr(model, "_get_param_groups"): # use the model function
# pyre-ignore[29]
p_groups = model._get_param_groups(self.lr, wd=self.weight_decay)
else:
allprm = [prm for prm in model.parameters() if prm.requires_grad]
p_groups = [{"params": allprm, "lr": self.lr}]
# Intialize the optimizer
if self.breed == "SGD":
optimizer = torch.optim.SGD(
p_groups,
lr=self.lr,
momentum=self.momentum,
weight_decay=self.weight_decay,
)
elif self.breed == "Adagrad":
optimizer = torch.optim.Adagrad(
p_groups, lr=self.lr, weight_decay=self.weight_decay
)
elif self.breed == "Adam":
optimizer = torch.optim.Adam(
p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
)
else:
raise ValueError("no such solver type %s" % self.breed)
logger.info(" -> solver type = %s" % self.breed)
# Load state from checkpoint
optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
if optimizer_state is not None:
logger.info(" -> setting loaded optimizer state")
optimizer.load_state_dict(optimizer_state)
# Initialize the learning rate scheduler
if self.lr_policy.casefold() == "MultiStepLR".casefold():
scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=self.multistep_lr_milestones,
gamma=self.gamma,
)
elif self.lr_policy.casefold() == "Exponential".casefold():
scheduler = torch.optim.lr_scheduler.LambdaLR(
optimizer,
lambda epoch: self.gamma ** (epoch / self.exponential_lr_step_size),
verbose=False,
)
else:
raise ValueError("no such lr policy %s" % self.lr_policy)
# When loading from checkpoint, this will make sure that the
# lr is correctly set even after returning.
for _ in range(last_epoch):
scheduler.step()
optimizer.zero_grad()
return optimizer, scheduler
def _get_optimizer_state(
self,
exp_dir: Optional[str],
accelerator: Optional[Accelerator] = None,
) -> Optional[Dict[str, Any]]:
"""
Load an optimizer state from a checkpoint.
"""
if exp_dir is None or not self.resume:
return None
if self.resume_epoch > 0:
save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
else:
save_path = model_io.find_last_checkpoint(exp_dir)
optimizer_state = None
if save_path is not None:
logger.info(f"Found previous optimizer state {save_path}.")
logger.info(" -> resuming")
opt_path = model_io.get_optimizer_path(save_path)
if os.path.isfile(opt_path):
map_location = None
if accelerator is not None and not accelerator.is_local_main_process:
map_location = {
"cuda:%d" % 0: "cuda:%d" % accelerator.local_process_index
}
optimizer_state = torch.load(opt_path, map_location)
else:
optimizer_state = None
return optimizer_state

View File

@@ -0,0 +1,365 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import random
import time
from typing import Any, Optional
import numpy as np
import torch
from accelerate import Accelerator
from pytorch3d.implicitron.dataset.data_source import Task
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
registry,
ReplaceableBase,
run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer.cameras import CamerasBase
from torch.utils.data import DataLoader
logger = logging.getLogger(__name__)
class TrainingLoopBase(ReplaceableBase):
def run(
self,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
**kwargs,
) -> None:
raise NotImplementedError()
@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
"""
Members:
eval_only: If True, only run evaluation using the test dataloader.
evaluator: An EvaluatorBase instance, used to evaluate training results.
max_epochs: Train for this many epochs. Note that if the model was
loaded from a checkpoint, we will restart training at the appropriate
epoch and run for (max_epochs - checkpoint_epoch) epochs.
seed: A random seed to ensure reproducibility.
store_checkpoints: If True, store model and optimizer state checkpoints.
store_checkpoints_purge: If >= 0, remove any checkpoints older or equal
to this many epochs.
test_interval: Evaluate on a test dataloader each `test_interval` epochs.
test_when_finished: If True, evaluate on a test dataloader when training
completes.
validation_interval: Validate each `validation_interval` epochs.
clip_grad: Optionally clip the gradient norms.
If set to a value <=0.0, no clipping
metric_print_interval: The batch interval at which the stats should be
logged.
visualize_interval: The batch interval at which the visualizations
should be plotted
"""
# Parameters of the outer training loop.
eval_only: bool = False
evaluator: EvaluatorBase
evaluator_class_type: str = "ImplicitronEvaluator"
max_epochs: int = 1000
seed: int = 0
store_checkpoints: bool = True
store_checkpoints_purge: int = 1
test_interval: int = -1
test_when_finished: bool = False
validation_interval: int = 1
# Parameters of a single training-validation step.
clip_grad: float = 0.0
metric_print_interval: int = 5
visualize_interval: int = 1000
def __post_init__(self):
run_auto_creation(self)
def run(
self,
*,
train_loader: DataLoader,
val_loader: Optional[DataLoader],
test_loader: Optional[DataLoader],
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
scheduler: Any,
accelerator: Optional[Accelerator],
all_train_cameras: Optional[CamerasBase],
device: torch.device,
exp_dir: str,
stats: Stats,
task: Task,
**kwargs,
):
"""
Entry point to run the training and validation loops
based on the specified config file.
"""
_seed_all_random_engines(self.seed)
start_epoch = stats.epoch + 1
assert scheduler.last_epoch == stats.epoch + 1
assert scheduler.last_epoch == start_epoch
# only run evaluation on the test dataloader
if self.eval_only:
if test_loader is not None:
self.evaluator.run(
all_train_cameras=all_train_cameras,
dataloader=test_loader,
device=device,
dump_to_json=True,
epoch=stats.epoch,
exp_dir=exp_dir,
model=model,
task=task,
)
return
else:
raise ValueError(
"Cannot evaluate and dump results to json, no test data provided."
)
# loop through epochs
for epoch in range(start_epoch, self.max_epochs):
# automatic new_epoch and plotting of stats at every epoch start
with stats:
# Make sure to re-seed random generators to ensure reproducibility
# even after restart.
_seed_all_random_engines(self.seed + epoch)
cur_lr = float(scheduler.get_last_lr()[-1])
logger.debug(f"scheduler lr = {cur_lr:1.2e}")
# train loop
self._training_or_validation_epoch(
accelerator=accelerator,
device=device,
epoch=epoch,
loader=train_loader,
model=model,
optimizer=optimizer,
stats=stats,
validation=False,
)
# val loop (optional)
if val_loader is not None and epoch % self.validation_interval == 0:
self._training_or_validation_epoch(
accelerator=accelerator,
device=device,
epoch=epoch,
loader=val_loader,
model=model,
optimizer=optimizer,
stats=stats,
validation=True,
)
# eval loop (optional)
if (
test_loader is not None
and self.test_interval > 0
and epoch % self.test_interval == 0
):
self.evaluator.run(
all_train_cameras=all_train_cameras,
device=device,
dataloader=test_loader,
model=model,
task=task,
)
assert stats.epoch == epoch, "inconsistent stats!"
self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)
scheduler.step()
new_lr = float(scheduler.get_last_lr()[-1])
if new_lr != cur_lr:
logger.info(f"LR change! {cur_lr} -> {new_lr}")
if self.test_when_finished:
if test_loader is not None:
self.evaluator.run(
all_train_cameras=all_train_cameras,
device=device,
dump_to_json=True,
epoch=stats.epoch,
exp_dir=exp_dir,
dataloader=test_loader,
model=model,
task=task,
)
else:
raise ValueError(
"Cannot evaluate and dump results to json, no test data provided."
)
def _training_or_validation_epoch(
self,
epoch: int,
loader: DataLoader,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
stats: Stats,
validation: bool,
*,
accelerator: Optional[Accelerator],
bp_var: str = "objective",
device: torch.device,
**kwargs,
) -> None:
"""
This is the main loop for training and evaluation including:
model forward pass, loss computation, backward pass and visualization.
Args:
epoch: The index of the current epoch
loader: The dataloader to use for the loop
model: The model module optionally loaded from checkpoint
optimizer: The optimizer module optionally loaded from checkpoint
stats: The stats struct, also optionally loaded from checkpoint
validation: If true, run the loop with the model in eval mode
and skip the backward pass
accelerator: An optional Accelerator instance.
bp_var: The name of the key in the model output `preds` dict which
should be used as the loss for the backward pass.
device: The device on which to run the model.
"""
if validation:
model.eval()
trainmode = "val"
else:
model.train()
trainmode = "train"
t_start = time.time()
# get the visdom env name
visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
viz = vis_utils.get_visdom_connection(
server=stats.visdom_server,
port=stats.visdom_port,
)
# Iterate through the batches
n_batches = len(loader)
for it, net_input in enumerate(loader):
last_iter = it == n_batches - 1
# move to gpu where possible (in place)
net_input = net_input.to(device)
# run the forward pass
if not validation:
optimizer.zero_grad()
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
)
else:
with torch.no_grad():
preds = model(
**{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
)
# make sure we dont overwrite something
assert all(k not in preds for k in net_input.keys())
# merge everything into one big dict
preds.update(net_input)
# update the stats logger
stats.update(preds, time_start=t_start, stat_set=trainmode)
# pyre-ignore [16]
assert stats.it[trainmode] == it, "inconsistent stat iteration number!"
# print textual status update
if it % self.metric_print_interval == 0 or last_iter:
stats.print(stat_set=trainmode, max_it=n_batches)
# visualize results
if (
(accelerator is None or accelerator.is_local_main_process)
and self.visualize_interval > 0
and it % self.visualize_interval == 0
):
prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
if hasattr(model, "visualize"):
# pyre-ignore [29]
model.visualize(
viz,
visdom_env_imgs,
preds,
prefix,
)
# optimizer step
if not validation:
loss = preds[bp_var]
assert torch.isfinite(loss).all(), "Non-finite loss!"
# backprop
if accelerator is None:
loss.backward()
else:
accelerator.backward(loss)
if self.clip_grad > 0.0:
# Optionally clip the gradient norms.
total_norm = torch.nn.utils.clip_grad_norm(
model.parameters(), self.clip_grad
)
if total_norm > self.clip_grad:
logger.debug(
f"Clipping gradient: {total_norm}"
+ f" with coef {self.clip_grad / float(total_norm)}."
)
optimizer.step()
def _checkpoint(
self,
accelerator: Optional[Accelerator],
epoch: int,
exp_dir: str,
model: ImplicitronModelBase,
optimizer: torch.optim.Optimizer,
stats: Stats,
):
"""
Save a model and its corresponding Stats object to a file, if
`self.store_checkpoints` is True. In addition, if
`self.store_checkpoints_purge` is True, remove any checkpoints older
than `self.store_checkpoints_purge` epochs old.
"""
if self.store_checkpoints and (
accelerator is None or accelerator.is_local_main_process
):
if self.store_checkpoints_purge > 0:
for prev_epoch in range(epoch - self.store_checkpoints_purge):
model_io.purge_epoch(exp_dir, prev_epoch)
outfile = model_io.get_checkpoint(exp_dir, epoch)
unwrapped_model = (
model if accelerator is None else accelerator.unwrap_model(model)
)
model_io.safe_save_model(
unwrapped_model, stats, outfile, optimizer=optimizer
)
def _seed_all_random_engines(seed: int) -> None:
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)