# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import logging
import os
import time
from typing import Any, List, Optional

import torch
from accelerate import Accelerator
from pytorch3d.implicitron.evaluation.evaluator import EvaluatorBase
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
from pytorch3d.implicitron.models.generic_model import EvaluationMode
from pytorch3d.implicitron.tools import model_io, vis_utils
from pytorch3d.implicitron.tools.config import (
    registry,
    ReplaceableBase,
    run_auto_creation,
)
from pytorch3d.implicitron.tools.stats import Stats
from torch.utils.data import DataLoader, Dataset

from .utils import seed_all_random_engines

logger = logging.getLogger(__name__)


class TrainingLoopBase(ReplaceableBase):
    """
    Members:
        evaluator: An EvaluatorBase instance, used to evaluate training results.
    """

    # pyre-fixme[13]: Attribute `evaluator` is never initialized.
    evaluator: Optional[EvaluatorBase]
    evaluator_class_type: Optional[str] = "ImplicitronEvaluator"

    def run(
        self,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        **kwargs,
    ) -> None:
        raise NotImplementedError()

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        raise NotImplementedError()
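

# A concrete training loop is selected by name through the Implicitron config
# system and the registry. As an illustrative sketch only (the class below is
# hypothetical and not part of this file), a custom replacement would be
# registered the same way as ImplicitronTrainingLoop further down:
#
#     @registry.register
#     class MyTrainingLoop(TrainingLoopBase):
#         def run(self, train_loader, val_loader, test_loader, train_dataset,
#                 model, optimizer, scheduler, **kwargs) -> None:
#             ...  # custom epoch loop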


@registry.register
class ImplicitronTrainingLoop(TrainingLoopBase):
    """
    Members:
        eval_only: If True, only run evaluation using the test dataloader.
        max_epochs: Train for this many epochs. Note that if the model was
            loaded from a checkpoint, we will restart training at the appropriate
            epoch and run for (max_epochs - checkpoint_epoch) epochs.
        store_checkpoints: If True, store model and optimizer state checkpoints.
        store_checkpoints_purge: If > 0, remove checkpoints that are more than
            this many epochs older than the current one, keeping only the most
            recent checkpoints.
        test_interval: Evaluate on a test dataloader every `test_interval`
            epochs; disabled if <= 0.
        test_when_finished: If True, evaluate on a test dataloader when training
            completes.
        validation_interval: Validate every `validation_interval` epochs.
        clip_grad: If > 0.0, clip gradient norms to this maximum value;
            if <= 0.0, no clipping is performed.
        metric_print_interval: The batch interval at which the stats should be
            logged.
        visualize_interval: The batch interval at which the visualizations
            should be plotted.
        visdom_env: The name of the Visdom environment to use for plotting.
        visdom_port: The Visdom port.
        visdom_server: Address of the Visdom server.
    """

    # Parameters of the outer training loop.
    eval_only: bool = False
    max_epochs: int = 1000
    store_checkpoints: bool = True
    store_checkpoints_purge: int = 1
    test_interval: int = -1
    test_when_finished: bool = False
    validation_interval: int = 1

    # Gradient clipping.
    clip_grad: float = 0.0

    # Visualization/logging parameters.
    metric_print_interval: int = 5
    visualize_interval: int = 1000
    visdom_env: str = ""
    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
    visdom_server: str = "http://127.0.0.1"

    def __post_init__(self):
        run_auto_creation(self)
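
    # The defaults above are normally overridden from the experiment's
    # OmegaConf/yaml config rather than in code. An illustrative sketch of the
    # relevant section (the exact key prefix depends on how the enclosing
    # experiment config names this member, so treat the path as an assumption):
    #
    #     training_loop_ImplicitronTrainingLoop_args:
    #       max_epochs: 400
    #       test_interval: 20
    #       clip_grad: 1.0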

    # pyre-fixme[14]: `run` overrides method defined in `TrainingLoopBase`
    #  inconsistently.
    def run(
        self,
        *,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader],
        test_loader: Optional[DataLoader],
        train_dataset: Dataset,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        scheduler: Any,
        accelerator: Optional[Accelerator],
        device: torch.device,
        exp_dir: str,
        stats: Stats,
        seed: int,
        **kwargs,
    ):
        """
        Entry point to run the training and validation loops
        based on the specified config file.
        """
        start_epoch = stats.epoch + 1
        assert scheduler.last_epoch == stats.epoch + 1
        assert scheduler.last_epoch == start_epoch

        # only run evaluation on the test dataloader
        if self.eval_only:
            if test_loader is not None:
                # pyre-fixme[16]: `Optional` has no attribute `run`.
                self.evaluator.run(
                    dataloader=test_loader,
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    model=model,
                )
                return
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

        # loop through epochs
        for epoch in range(start_epoch, self.max_epochs):
            # automatic new_epoch and plotting of stats at every epoch start
            with stats:
                # Make sure to re-seed random generators to ensure reproducibility
                # even after restart.
                seed_all_random_engines(seed + epoch)

                cur_lr = float(scheduler.get_last_lr()[-1])
                logger.debug(f"scheduler lr = {cur_lr:1.2e}")

                # train loop
                self._training_or_validation_epoch(
                    accelerator=accelerator,
                    device=device,
                    epoch=epoch,
                    loader=train_loader,
                    model=model,
                    optimizer=optimizer,
                    stats=stats,
                    validation=False,
                )

                # val loop (optional)
                if val_loader is not None and epoch % self.validation_interval == 0:
                    self._training_or_validation_epoch(
                        accelerator=accelerator,
                        device=device,
                        epoch=epoch,
                        loader=val_loader,
                        model=model,
                        optimizer=optimizer,
                        stats=stats,
                        validation=True,
                    )

                # eval loop (optional)
                if (
                    test_loader is not None
                    and self.test_interval > 0
                    and epoch % self.test_interval == 0
                ):
                    self.evaluator.run(
                        device=device,
                        dataloader=test_loader,
                        model=model,
                    )

                assert stats.epoch == epoch, "inconsistent stats!"
                self._checkpoint(accelerator, epoch, exp_dir, model, optimizer, stats)

                scheduler.step()
                new_lr = float(scheduler.get_last_lr()[-1])
                if new_lr != cur_lr:
                    logger.info(f"LR change! {cur_lr} -> {new_lr}")

        if self.test_when_finished:
            if test_loader is not None:
                self.evaluator.run(
                    device=device,
                    dump_to_json=True,
                    epoch=stats.epoch,
                    exp_dir=exp_dir,
                    dataloader=test_loader,
                    model=model,
                )
            else:
                raise ValueError(
                    "Cannot evaluate and dump results to json, no test data provided."
                )

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        """
        Load Stats that correspond to the model's log_vars and resume_epoch.

        Args:
            log_vars: A list of variable names to log. Should be a subset of the
                `preds` returned by the forward function of the corresponding
                ImplicitronModelBase instance.
            exp_dir: Root experiment directory.
            resume: If False, do not load stats from the checkpoint specified by
                resume and resume_epoch; instead, create a fresh stats object.
            resume_epoch: If positive, load stats from the checkpoint of this
                epoch; otherwise, load stats from the last available checkpoint.

        Returns:
            stats: The stats structure (optionally loaded from checkpoint).
        """
        # Init the stats struct
        visdom_env_charts = (
            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
        )
        stats = Stats(
            # log_vars should be a list, but OmegaConf might load them as ListConfig
            list(log_vars),
            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
            visdom_env=visdom_env_charts,
            visdom_server=self.visdom_server,
            visdom_port=self.visdom_port,
        )

        model_path = None
        if resume:
            if resume_epoch > 0:
                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
                if not os.path.isfile(model_path):
                    raise FileNotFoundError(
                        f"Cannot find stats from epoch {resume_epoch}."
                    )
            else:
                model_path = model_io.find_last_checkpoint(exp_dir)

        if model_path is not None:
            stats_path = model_io.get_stats_path(model_path)
            stats_load = model_io.load_stats(stats_path)

            # Determine if stats should be reset
            if resume:
                if stats_load is None:
                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
                    logger.info(f"Estimated resume epoch = {last_epoch}")

                    # Reset the stats struct
                    for _ in range(last_epoch + 1):
                        stats.new_epoch()
                    assert last_epoch == stats.epoch
                else:
                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
                    stats = stats_load

                # Update stats properties in case they were reset on load
                stats.visdom_env = visdom_env_charts
                stats.visdom_server = self.visdom_server
                stats.visdom_port = self.visdom_port
                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
                stats.synchronize_logged_vars(log_vars)
            else:
                logger.info("Clearing stats")

        return stats
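
    # Illustrative call (a sketch with assumed variable names; "objective" is
    # the default backpropagation key used by _training_or_validation_epoch
    # below):
    #
    #     stats = training_loop.load_stats(log_vars=["objective"], exp_dir=exp_dir)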

    def _training_or_validation_epoch(
        self,
        epoch: int,
        loader: DataLoader,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
        validation: bool,
        *,
        accelerator: Optional[Accelerator],
        bp_var: str = "objective",
        device: torch.device,
        **kwargs,
    ) -> None:
        """
        This is the main loop for training and evaluation including:
        model forward pass, loss computation, backward pass and visualization.

        Args:
            epoch: The index of the current epoch.
            loader: The dataloader to use for the loop.
            model: The model module, optionally loaded from checkpoint.
            optimizer: The optimizer module, optionally loaded from checkpoint.
            stats: The stats struct, also optionally loaded from checkpoint.
            validation: If True, run the loop with the model in eval mode
                and skip the backward pass.
            accelerator: An optional Accelerator instance.
            bp_var: The name of the key in the model output `preds` dict which
                should be used as the loss for the backward pass.
            device: The device on which to run the model.
        """

        if validation:
            model.eval()
            trainmode = "val"
        else:
            model.train()
            trainmode = "train"

        t_start = time.time()

        # get the visdom env name
        visdom_env_imgs = stats.visdom_env + "_images_" + trainmode
        viz = vis_utils.get_visdom_connection(
            server=stats.visdom_server,
            port=stats.visdom_port,
        )

        # Iterate through the batches
        n_batches = len(loader)
        for it, net_input in enumerate(loader):
            last_iter = it == n_batches - 1

            # move to GPU where possible (in place)
            net_input = net_input.to(device)

            # run the forward pass
            if not validation:
                optimizer.zero_grad()
                preds = model(
                    **{**net_input, "evaluation_mode": EvaluationMode.TRAINING}
                )
            else:
                with torch.no_grad():
                    preds = model(
                        **{**net_input, "evaluation_mode": EvaluationMode.EVALUATION}
                    )

            # make sure we don't overwrite anything
            assert all(k not in preds for k in net_input.keys())
            # merge everything into one big dict
            preds.update(net_input)

            # update the stats logger
            stats.update(preds, time_start=t_start, stat_set=trainmode)
            # pyre-ignore [16]
            assert stats.it[trainmode] == it, "inconsistent stat iteration number!"

            # print textual status update
            if it % self.metric_print_interval == 0 or last_iter:
                std_out = stats.get_status_string(stat_set=trainmode, max_it=n_batches)
                logger.info(std_out)

            # visualize results
            if (
                (accelerator is None or accelerator.is_local_main_process)
                and self.visualize_interval > 0
                and it % self.visualize_interval == 0
            ):
                prefix = f"e{stats.epoch}_it{stats.it[trainmode]}"
                if hasattr(model, "visualize"):
                    # pyre-fixme[29]: `Union[Tensor, Module]` is not a function.
                    model.visualize(
                        viz,
                        visdom_env_imgs,
                        preds,
                        prefix,
                    )

            # optimizer step
            if not validation:
                loss = preds[bp_var]
                assert torch.isfinite(loss).all(), "Non-finite loss!"
                # backprop
                if accelerator is None:
                    loss.backward()
                else:
                    accelerator.backward(loss)
                if self.clip_grad > 0.0:
                    # Optionally clip the gradient norms.
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), self.clip_grad
                    )
                    if total_norm > self.clip_grad:
                        logger.debug(
                            f"Clipping gradient: {total_norm}"
                            + f" with coef {self.clip_grad / float(total_norm)}."
                        )

                optimizer.step()
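
    # The loop above assumes the model's forward() accepts the frame data as
    # keyword arguments plus an `evaluation_mode` flag, and returns a dict of
    # predictions containing the scalar loss under the `bp_var` key
    # ("objective" by default). A minimal sketch of a compatible model (the
    # class and names below are hypothetical, purely for illustration):
    #
    #     class MyModel(ImplicitronModelBase):
    #         def forward(self, *, evaluation_mode, **frame_data):
    #             preds = {"objective": some_loss_tensor}
    #             return preds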

    def _checkpoint(
        self,
        accelerator: Optional[Accelerator],
        epoch: int,
        exp_dir: str,
        model: ImplicitronModelBase,
        optimizer: torch.optim.Optimizer,
        stats: Stats,
    ):
        """
        Save a model and its corresponding Stats object to a file, if
        `self.store_checkpoints` is True. In addition, if
        `self.store_checkpoints_purge` is > 0, remove any checkpoints more
        than `self.store_checkpoints_purge` epochs old.
        """
        if self.store_checkpoints and (
            accelerator is None or accelerator.is_local_main_process
        ):
            if self.store_checkpoints_purge > 0:
                for prev_epoch in range(epoch - self.store_checkpoints_purge):
                    model_io.purge_epoch(exp_dir, prev_epoch)
            outfile = model_io.get_checkpoint(exp_dir, epoch)
            unwrapped_model = (
                model if accelerator is None else accelerator.unwrap_model(model)
            )
            model_io.safe_save_model(
                unwrapped_model, stats, outfile, optimizer=optimizer
            )
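

# Usage sketch: in the Implicitron experiment runner, equivalent wiring is
# performed by the experiment/config machinery; the standalone calls and
# variable names below are illustrative assumptions only. Note that
# ImplicitronTrainingLoop is normally constructed through the config system
# (which also creates its evaluator), not instantiated directly:
#
#     training_loop = ...  # built from the experiment config
#     stats = training_loop.load_stats(log_vars=["objective"], exp_dir=exp_dir)
#     training_loop.run(
#         train_loader=train_loader,
#         val_loader=val_loader,
#         test_loader=test_loader,
#         train_dataset=train_dataset,
#         model=model,
#         optimizer=optimizer,
#         scheduler=scheduler,
#         accelerator=None,
#         device=torch.device("cuda"),
#         exp_dir=exp_dir,
#         stats=stats,
#         seed=42,
#     )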