Move load_stats to TrainingLoop

Summary:
Stats are logically tied to the training loop, not to the model, so load_stats moves from the model factory to the training loop.

Also removes resume_epoch from OptimizerFactory in favor of a single source of truth, ModelFactory. This removes the need for the config consistency checks between the two factories.
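To make the new split concrete, here is a minimal, hypothetical sketch of the shape after this change (Stats, the factories, and the checkpoint I/O are simplified stand-ins, not the actual PyTorch3D classes):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Stats:
    # Stand-in for the real per-epoch training statistics object.
    log_vars: List[str]
    epoch: int = -1


@dataclass
class ModelFactory:
    # Single source of truth for resuming; OptimizerFactory no longer
    # carries its own resume/resume_epoch, so the settings cannot drift.
    resume: bool = True
    resume_epoch: Optional[int] = None


class TrainingLoop:
    def load_stats(
        self,
        log_vars: List[str],
        exp_dir: str,
        resume: bool,
        resume_epoch: Optional[int] = None,
    ) -> Stats:
        # When resuming, restore the stats recorded up to resume_epoch from
        # exp_dir; otherwise start fresh. (Checkpoint I/O elided here.)
        if resume and resume_epoch is not None:
            return Stats(log_vars=log_vars, epoch=resume_epoch)
        return Stats(log_vars=log_vars)


model_factory = ModelFactory(resume=True, resume_epoch=9)
stats = TrainingLoop().load_stats(
    log_vars=["loss"],
    exp_dir="./exp",
    resume=model_factory.resume,
    resume_epoch=model_factory.resume_epoch,
)
start_epoch = stats.epoch + 1  # 10: training continues after the resumed epoch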

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
Author: David Novotny
Date: 2022-08-02 15:40:53 -07:00
Committed by: Facebook GitHub Bot
Parent: 760305e044
Commit: c3f8dad55c
11 changed files with 259 additions and 189 deletions


@@ -147,9 +147,6 @@ class Experiment(Configurable): # pyre-ignore: 13
         run_auto_creation(self)
 
     def run(self) -> None:
-        # Make sure the config settings are self-consistent.
-        self._check_config_consistent()
-
         # Initialize the accelerator if desired.
         if no_accelerate:
             accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
         )
 
-        stats = self.model_factory.load_stats(
-            exp_dir=self.exp_dir,
+        stats = self.training_loop.load_stats(
             log_vars=model.log_vars,
+            exp_dir=self.exp_dir,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,  # pyre-ignore [16]
         )
         start_epoch = stats.epoch + 1
@@ -190,6 +189,8 @@ class Experiment(Configurable): # pyre-ignore: 13
             exp_dir=self.exp_dir,
             last_epoch=start_epoch,
             model=model,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,
         )
 
         # Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable): # pyre-ignore: 13
             seed=self.seed,
         )
 
-    def _check_config_consistent(self) -> None:
-        if hasattr(self.optimizer_factory, "resume") and hasattr(
-            self.model_factory, "resume"
-        ):
-            assert (
-                # pyre-ignore [16]
-                not self.optimizer_factory.resume
-                # pyre-ignore [16]
-                or self.model_factory.resume
-            ), "Cannot resume the optimizer without resuming the model."
-
-        if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
-            self.model_factory, "resume_epoch"
-        ):
-            assert (
-                # pyre-ignore [16]
-                self.optimizer_factory.resume_epoch
-                # pyre-ignore [16]
-                == self.model_factory.resume_epoch
-            ), "Optimizer and model must resume from the same epoch."
 
 def _setup_envvars_for_cluster() -> bool:
     """