Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-22 07:10:34 +08:00
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so load_stats moves from ModelFactory to TrainingLoop. This also removes resume_epoch from OptimizerFactory in favor of a single source of truth, ModelFactory, which eliminates the need for the config-consistency checks.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
committed by Facebook GitHub Bot
commit c3f8dad55c
parent 760305e044
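
For orientation before the diff: a minimal sketch of the shape the moved API takes. The Stats class, the stats file name, and the JSON layout below are illustrative assumptions, not Implicitron's actual implementation; only the load_stats keyword arguments (log_vars, exp_dir, resume, resume_epoch) and the call site are taken from the diff itself.

import json
import os
from typing import List


class Stats:
    """Hypothetical stand-in for Implicitron's training-statistics object."""

    def __init__(self, log_vars: List[str], epoch: int = -1) -> None:
        self.log_vars = log_vars
        self.epoch = epoch  # -1 means no epochs completed yet


class TrainingLoop:
    """After this diff, the training loop owns stats loading."""

    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool = True,
        resume_epoch: int = -1,
    ) -> Stats:
        stats_path = os.path.join(exp_dir, "stats.json")  # assumed file name
        if resume and os.path.isfile(stats_path):
            with open(stats_path) as f:
                saved = json.load(f)  # assumed layout: {"epoch": <int>, ...}
            # Resume from the explicitly requested epoch if one was given,
            # otherwise pick up where the saved history left off.
            epoch = resume_epoch if resume_epoch >= 0 else saved["epoch"]
            return Stats(log_vars, epoch=epoch)
        return Stats(log_vars)  # fresh run


# Call site, as in the diff below: resume and resume_epoch are read from
# the model factory, so stats always resume in lockstep with the model:
#
#     stats = training_loop.load_stats(
#         log_vars=model.log_vars,
#         exp_dir=exp_dir,
#         resume=model_factory.resume,
#         resume_epoch=model_factory.resume_epoch,
#     )
#     start_epoch = stats.epoch + 1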
@@ -147,9 +147,6 @@ class Experiment(Configurable): # pyre-ignore: 13
         run_auto_creation(self)
 
     def run(self) -> None:
-        # Make sure the config settings are self-consistent.
-        self._check_config_consistent()
-
         # Initialize the accelerator if desired.
         if no_accelerate:
             accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
         )
 
-        stats = self.model_factory.load_stats(
-            exp_dir=self.exp_dir,
+        stats = self.training_loop.load_stats(
+            log_vars=model.log_vars,
+            exp_dir=self.exp_dir,
+            resume=self.model_factory.resume,
             resume_epoch=self.model_factory.resume_epoch,  # pyre-ignore [16]
         )
         start_epoch = stats.epoch + 1
 
@@ -190,6 +189,8 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
             last_epoch=start_epoch,
             model=model,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,
         )
 
         # Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable): # pyre-ignore: 13
             seed=self.seed,
         )
 
-    def _check_config_consistent(self) -> None:
-        if hasattr(self.optimizer_factory, "resume") and hasattr(
-            self.model_factory, "resume"
-        ):
-            assert (
-                # pyre-ignore [16]
-                not self.optimizer_factory.resume
-                # pyre-ignore [16]
-                or self.model_factory.resume
-            ), "Cannot resume the optimizer without resuming the model."
-        if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
-            self.model_factory, "resume_epoch"
-        ):
-            assert (
-                # pyre-ignore [16]
-                self.optimizer_factory.resume_epoch
-                # pyre-ignore [16]
-                == self.model_factory.resume_epoch
-            ), "Optimizer and model must resume from the same epoch."
-
 
 def _setup_envvars_for_cluster() -> bool:
     """
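
Design note: with resume and resume_epoch owned by ModelFactory alone, the optimizer factory call site now reads both values from self.model_factory, so the two invariants that _check_config_consistent used to assert at runtime (the optimizer resumes only if the model resumes, and both resume from the same epoch) hold by construction, which is why the method can be deleted outright.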