Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-20 06:10:34 +08:00
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so load_stats moves to the training loop. This also removes resume_epoch from OptimizerFactory in favor of a single place, ModelFactory, which removes the need for config consistency checks.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
commit c3f8dad55c, parent 760305e044, committed by Facebook GitHub Bot
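In practice this changes who the experiment driver asks for what: the model factory only builds (and optionally restores) the model, the training loop owns the Stats object, and resume/resume_epoch are forwarded from the model factory's configuration into the optimizer factory call. The rough sketch below illustrates that flow; it assumes pre-built model_factory, optimizer_factory, training_loop, accelerator and exp_dir objects (as the Implicitron trainer's config system would provide) and omits any call arguments not visible in this diff.

# Illustrative driver snippet only; not part of the commit.
model = model_factory(exp_dir=exp_dir, accelerator=accelerator)

# Stats now come from the training loop, not the model factory.
stats = training_loop.load_stats(
    log_vars=list(model.log_vars),   # assumes the model defines log_vars
    exp_dir=exp_dir,
    resume=model_factory.resume,
    resume_epoch=model_factory.resume_epoch,
)

# resume/resume_epoch are no longer fields of the optimizer factory;
# they are forwarded from the single source of truth, the model factory.
optimizer, scheduler = optimizer_factory(
    model=model,
    accelerator=accelerator,
    exp_dir=exp_dir,
    resume=model_factory.resume,
    resume_epoch=model_factory.resume_epoch,
)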
@@ -6,13 +6,13 @@

 import logging
 import os
-from typing import List, Optional
+from typing import Optional

 import torch.optim

 from accelerate import Accelerator
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
-from pytorch3d.implicitron.tools import model_io, vis_utils
+from pytorch3d.implicitron.tools import model_io
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)


 class ModelFactoryBase(ReplaceableBase):
+
+    resume: bool = True  # resume from the last checkpoint
+
     def __call__(self, **kwargs) -> ImplicitronModelBase:
         """
         Initialize the model (possibly from a previously saved state).
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
     A factory class that initializes an implicit rendering model.

     Members:
-        force_load: If True, throw a FileNotFoundError if `resume` is True but
-            a model checkpoint cannot be found.
         model: An ImplicitronModelBase object.
         resume: If True, attempt to load the last checkpoint from `exp_dir`
             passed to __call__. Failure to do so will return a model with ini-
-            tial weights unless `force_load` is True.
+            tial weights unless `force_resume` is True.
         resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
-        visdom_env: The name of the Visdom environment to use for plotting.
-        visdom_port: The Visdom port.
-        visdom_server: Address of the Visdom server.
+        force_resume: If True, throw a FileNotFoundError if `resume` is True but
+            a model checkpoint cannot be found.

    """

-    force_load: bool = False
     model: ImplicitronModelBase
     model_class_type: str = "GenericModel"
-    resume: bool = False
+    resume: bool = True
     resume_epoch: int = -1
-    visdom_env: str = ""
-    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
-    visdom_server: str = "http://127.0.0.1"
+    force_resume: bool = False

     def __post_init__(self):
         run_auto_creation(self)
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
             model: The model with optionally loaded weights from checkpoint

         Raise:
-            FileNotFoundError if `force_load` is True but checkpoint not found.
+            FileNotFoundError if `force_resume` is True but checkpoint not found.
         """
         # Determine the network outputs that should be logged
         if hasattr(self.model, "log_vars"):
-            log_vars = list(self.model.log_vars)  # pyre-ignore [6]
+            log_vars = list(self.model.log_vars)
         else:
             log_vars = ["objective"]

         # Retrieve the last checkpoint
         if self.resume_epoch > 0:
             # Resume from a certain epoch
             model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
             if not os.path.isfile(model_path):
                 raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
         else:
             # Retrieve the last checkpoint
             model_path = model_io.find_last_checkpoint(exp_dir)

         if model_path is not None:
-            logger.info("found previous model %s" % model_path)
-            if self.force_load or self.resume:
-                logger.info(" -> resuming")
+            logger.info(f"Found previous model {model_path}")
+            if self.force_resume or self.resume:
+                logger.info("Resuming.")

                 map_location = None
                 if accelerator is not None and not accelerator.is_local_main_process:
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
                 except RuntimeError as e:
                     logger.error(e)
                     logger.info(
-                        "Cant load state dict in strict mode! -> trying non-strict"
+                        "Cannot load state dict in strict mode! -> trying non-strict"
                     )
                     self.model.load_state_dict(model_state_dict, strict=False)
-                self.model.log_vars = log_vars  # pyre-ignore [16]
+                self.model.log_vars = log_vars
             else:
-                logger.info(" -> but not resuming -> starting from scratch")
-        elif self.force_load:
+                logger.info("Not resuming -> starting from scratch.")
+        elif self.force_resume:
             raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")

         return self.model

-    def load_stats(
-        self,
-        log_vars: List[str],
-        exp_dir: str,
-        clear_stats: bool = False,
-        **kwargs,
-    ) -> Stats:
-        """
-        Load Stats that correspond to the model's log_vars.
-
-        Args:
-            log_vars: A list of variable names to log. Should be a subset of the
-                `preds` returned by the forward function of the corresponding
-                ImplicitronModelBase instance.
-            exp_dir: Root experiment directory.
-            clear_stats: If True, do not load stats from the checkpoint speci-
-                fied by self.resume and self.resume_epoch; instead, create a
-                fresh stats object.
-
-            stats: The stats structure (optionally loaded from checkpoint)
-        """
-        # Init the stats struct
-        visdom_env_charts = (
-            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
-        )
-        stats = Stats(
-            # log_vars should be a list, but OmegaConf might load them as ListConfig
-            list(log_vars),
-            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
-            visdom_env=visdom_env_charts,
-            verbose=False,
-            visdom_server=self.visdom_server,
-            visdom_port=self.visdom_port,
-        )
-        if self.resume_epoch > 0:
-            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
-        else:
-            model_path = model_io.find_last_checkpoint(exp_dir)
-
-        if model_path is not None:
-            stats_path = model_io.get_stats_path(model_path)
-            stats_load = model_io.load_stats(stats_path)
-
-            # Determine if stats should be reset
-            if not clear_stats:
-                if stats_load is None:
-                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
-                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
-                    logger.info(f"Estimated resume epoch = {last_epoch}")
-
-                    # Reset the stats struct
-                    for _ in range(last_epoch + 1):
-                        stats.new_epoch()
-                    assert last_epoch == stats.epoch
-                else:
-                    stats = stats_load
-
-                # Update stats properties incase it was reset on load
-                stats.visdom_env = visdom_env_charts
-                stats.visdom_server = self.visdom_server
-                stats.visdom_port = self.visdom_port
-                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
-                stats.synchronize_logged_vars(log_vars)
-            else:
-                logger.info(" -> clearing stats")
-
-        return stats
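To make the resume semantics of the renamed flag explicit, here is a small self-contained sketch of the decision implemented by __call__ above. It mirrors the observable behaviour only; it is not the factory's actual code, and the function name is illustrative.

def resume_behavior(checkpoint_exists: bool, resume: bool, force_resume: bool) -> str:
    # force_resume requests loading when a checkpoint exists and raises when
    # none can be found; resume alone falls back to initial weights.
    if checkpoint_exists:
        if force_resume or resume:
            return "load weights from the checkpoint"
        return "checkpoint found but ignored, start from initial weights"
    if force_resume:
        raise FileNotFoundError("force_resume=True but no checkpoint in exp_dir")
    return "start from initial weights"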
@@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         multistep_lr_milestones: With MultiStepLR policy only: list of
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
-        resume: If True, attempt to load the last checkpoint from `exp_dir`
-            passed to __call__. Failure to do so will return a newly initialized
-            optimizer.
-        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
-            `resume_epoch` <= 0, then resume from the latest checkpoint.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
     """
@@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     lr_policy: str = "MultiStepLR"
     momentum: float = 0.9
     multistep_lr_milestones: tuple = ()
-    resume: bool = False
-    resume_epoch: int = -1
     weight_decay: float = 0.0

     def __post_init__(self):
@@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         model: ImplicitronModelBase,
         accelerator: Optional[Accelerator] = None,
         exp_dir: Optional[str] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
         **kwargs,
     ) -> Tuple[torch.optim.Optimizer, Any]:
         """
@@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             model: The model with optionally loaded weights.
             accelerator: An optional Accelerator instance.
             exp_dir: Root experiment directory.
-
+            resume: If True, attempt to load optimizer checkpoint from exp_dir.
+                Failure to do so will return a newly initialized optimizer.
+            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+                `resume_epoch` <= 0, then resume from the latest checkpoint.
         Returns:
             An optimizer module (optionally loaded from a checkpoint) and
             a learning rate scheduler module (should be a subclass of torch.optim's
@@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
             )
         else:
-            raise ValueError("no such solver type %s" % self.breed)
-        logger.info(" -> solver type = %s" % self.breed)
+            raise ValueError(f"No such solver type {self.breed}")
+        logger.info(f"Solver type = {self.breed}")

         # Load state from checkpoint
-        optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
+        optimizer_state = self._get_optimizer_state(
+            exp_dir,
+            accelerator,
+            resume_epoch=resume_epoch,
+            resume=resume,
+        )
         if optimizer_state is not None:
-            logger.info(" -> setting loaded optimizer state")
+            logger.info("Setting loaded optimizer state.")
             optimizer.load_state_dict(optimizer_state)

         # Initialize the learning rate scheduler
@@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         self,
         exp_dir: Optional[str],
         accelerator: Optional[Accelerator] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
     ) -> Optional[Dict[str, Any]]:
         """
         Load an optimizer state from a checkpoint.
+
+        resume: If True, attempt to load the last checkpoint from `exp_dir`
+            passed to __call__. Failure to do so will return a newly initialized
+            optimizer.
+        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+            `resume_epoch` <= 0, then resume from the latest checkpoint.
         """
-        if exp_dir is None or not self.resume:
+        if exp_dir is None or not resume:
             return None
-        if self.resume_epoch > 0:
-            save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
+        if resume_epoch > 0:
+            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+            if not os.path.isfile(save_path):
+                raise FileNotFoundError(
+                    f"Cannot find optimizer from epoch {resume_epoch}."
+                )
         else:
             save_path = model_io.find_last_checkpoint(exp_dir)
         optimizer_state = None
         if save_path is not None:
-            logger.info(f"Found previous optimizer state {save_path}.")
-            logger.info(" -> resuming")
+            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
             opt_path = model_io.get_optimizer_path(save_path)

             if os.path.isfile(opt_path):
@@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 }
                 optimizer_state = torch.load(opt_path, map_location)
             else:
-                optimizer_state = None
+                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state
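Note the behaviour change at the end of _get_optimizer_state: a checkpoint whose optimizer file is missing now raises FileNotFoundError instead of silently returning None and reinitializing the optimizer. The sketch below shows how a hypothetical caller could restore the old behaviour; optimizer_factory, model, accelerator, exp_dir and logger are assumed to exist, and call arguments not visible in this diff are omitted.

# Illustrative caller snippet only; not part of the commit.
try:
    optimizer, scheduler = optimizer_factory(
        model=model, accelerator=accelerator, exp_dir=exp_dir, resume=True
    )
except FileNotFoundError as err:
    # Fall back to a fresh optimizer when only the optimizer state is missing.
    logger.warning("No optimizer state found, starting fresh: %s", err)
    optimizer, scheduler = optimizer_factory(
        model=model, accelerator=accelerator, exp_dir=exp_dir, resume=False
    )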
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.

 import logging
+import os
 import time
-from typing import Any, Optional
+from typing import Any, List, Optional

 import torch
 from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
     ) -> None:
         raise NotImplementedError()

+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        raise NotImplementedError()
+

 @registry.register
 class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
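Since load_stats is now part of the TrainingLoopBase interface, a custom loop can override it. Below is a minimal sketch of a subclass that always starts with fresh stats; the import paths for Stats and ImplicitronTrainingLoop are assumptions about where these classes live and may differ in your checkout, and the subclass itself is purely illustrative.

from typing import List

from pytorch3d.implicitron.tools.config import registry
from pytorch3d.implicitron.tools.stats import Stats  # assumed location of Stats

# Assumed import path; adjust to wherever the trainer's training loop lives.
from projects.implicitron_trainer.impl.training_loop import ImplicitronTrainingLoop


@registry.register
class FreshStatsTrainingLoop(ImplicitronTrainingLoop):
    # A training loop that ignores checkpointed stats and always starts fresh.
    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
        **kwargs,
    ) -> Stats:
        # Delegating with resume=False makes the parent build a new Stats object.
        return super().load_stats(log_vars, exp_dir, resume=False, **kwargs)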
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
             logged.
         visualize_interval: The batch interval at which the visualizations
             should be plotted
+        visdom_env: The name of the Visdom environment to use for plotting.
+        visdom_port: The Visdom port.
+        visdom_server: Address of the Visdom server.
     """

     # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
     test_when_finished: bool = False
     validation_interval: int = 1

     # Parameters of a single training-validation step.
     # Gradient clipping.
     clip_grad: float = 0.0

     # Visualization/logging parameters.
     metric_print_interval: int = 5
     visualize_interval: int = 1000
+    visdom_env: str = ""
+    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
+    visdom_server: str = "http://127.0.0.1"

     def __post_init__(self):
         run_auto_creation(self)
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
                 "Cannot evaluate and dump results to json, no test data provided."
             )

+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        """
+        Load Stats that correspond to the model's log_vars and resume_epoch.
+
+        Args:
+            log_vars: A list of variable names to log. Should be a subset of the
+                `preds` returned by the forward function of the corresponding
+                ImplicitronModelBase instance.
+            exp_dir: Root experiment directory.
+            resume: If False, do not load stats from the checkpoint speci-
+                fied by resume and resume_epoch; instead, create a fresh stats object.
+
+            stats: The stats structure (optionally loaded from checkpoint)
+        """
+        # Init the stats struct
+        visdom_env_charts = (
+            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
+        )
+        stats = Stats(
+            # log_vars should be a list, but OmegaConf might load them as ListConfig
+            list(log_vars),
+            visdom_env=visdom_env_charts,
+            verbose=False,
+            visdom_server=self.visdom_server,
+            visdom_port=self.visdom_port,
+        )
+
+        model_path = None
+        if resume:
+            if resume_epoch > 0:
+                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+                if not os.path.isfile(model_path):
+                    raise FileNotFoundError(
+                        f"Cannot find stats from epoch {resume_epoch}."
+                    )
+            else:
+                model_path = model_io.find_last_checkpoint(exp_dir)
+
+        if model_path is not None:
+            stats_path = model_io.get_stats_path(model_path)
+            stats_load = model_io.load_stats(stats_path)
+
+            # Determine if stats should be reset
+            if resume:
+                if stats_load is None:
+                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
+                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
+                    logger.info(f"Estimated resume epoch = {last_epoch}")
+
+                    # Reset the stats struct
+                    for _ in range(last_epoch + 1):
+                        stats.new_epoch()
+                    assert last_epoch == stats.epoch
+                else:
+                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
+                    stats = stats_load
+
+                # Update stats properties incase it was reset on load
+                stats.visdom_env = visdom_env_charts
+                stats.visdom_server = self.visdom_server
+                stats.visdom_port = self.visdom_port
+                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
+                stats.synchronize_logged_vars(log_vars)
+            else:
+                logger.info("Clearing stats")
+
+        return stats
+
     def _training_or_validation_epoch(
         self,
         epoch: int,
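Typical ways a driver might call the new method are sketched below, assuming training_loop, model and exp_dir already exist; the calls are illustrative only.

# Fresh stats, ignoring any checkpoint in exp_dir:
stats = training_loop.load_stats(list(model.log_vars), exp_dir, resume=False)

# Resume stats from the latest checkpoint, or start fresh if none exists:
stats = training_loop.load_stats(list(model.log_vars), exp_dir, resume=True)

# Resume stats from a specific epoch; a missing checkpoint raises FileNotFoundError:
stats = training_loop.load_stats(
    list(model.log_vars), exp_dir, resume=True, resume_epoch=10
)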