Move load_stats to TrainingLoop

Summary:
Stats are logically connected to the training loop, not to the model; hence, load_stats moves to the training loop.

Also removes resume_epoch from OptimizerFactory so that it lives in a single place, ModelFactory. This removes the need for config consistency checks between the two factories.
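
For illustration, a sketch of the intended call-site change (the names experiment, training_loop, and model_factory here are illustrative, not the exact runner code):

    # Before: stats loading lived on the model.
    #     stats = model.load_stats(exp_dir, ...)
    # After: the training loop owns it, and the resume flags come from
    # ModelFactory alone.
    stats = training_loop.load_stats(
        log_vars=model.log_vars,
        exp_dir=exp_dir,
        resume=model_factory.resume,
        resume_epoch=model_factory.resume_epoch,
    )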

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
Author: David Novotny
Date: 2022-08-02 15:40:53 -07:00
Committed by: Facebook GitHub Bot
Parent: 760305e044
Commit: c3f8dad55c
11 changed files with 259 additions and 189 deletions


@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
 import time
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
     ) -> None:
         raise NotImplementedError()
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
             logged.
         visualize_interval: The batch interval at which the visualizations
             should be plotted
+        visdom_env: The name of the Visdom environment to use for plotting.
+        visdom_port: The Visdom port.
+        visdom_server: Address of the Visdom server.
     """
 
     # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
     test_when_finished: bool = False
     validation_interval: int = 1
 
     # Parameters of a single training-validation step.
+    # Gradient clipping.
     clip_grad: float = 0.0
+    # Visualization/logging parameters.
     metric_print_interval: int = 5
     visualize_interval: int = 1000
+    visdom_env: str = ""
+    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
+    visdom_server: str = "http://127.0.0.1"
 
     def __post_init__(self):
         run_auto_creation(self)
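
As an aside, the visdom_port default above is resolved from the environment when the class body executes (i.e. at import), not when the loop is instantiated. A standalone Python sketch of the same pattern (not implicitron code):

    import os

    # Simulate launching the process with VISDOM_PORT set.
    os.environ["VISDOM_PORT"] = "9999"

    class LoopConfig:
        # Evaluated once, when the class body runs, like the field default above.
        visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))

    print(LoopConfig.visdom_port)  # -> 9999

    # Changing the variable afterwards does not affect the computed default.
    os.environ["VISDOM_PORT"] = "8000"
    print(LoopConfig.visdom_port)  # still 9999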
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
                 "Cannot evaluate and dump results to json, no test data provided."
             )
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        """
+        Load Stats that correspond to the model's log_vars and resume_epoch.
+
+        Args:
+            log_vars: A list of variable names to log. Should be a subset of the
+                `preds` returned by the forward function of the corresponding
+                ImplicitronModelBase instance.
+            exp_dir: Root experiment directory.
+            resume: If False, do not load stats from the checkpoint specified
+                by resume_epoch; instead, create a fresh stats object.
+            resume_epoch: If positive, resume from the checkpoint of this epoch;
+                otherwise (the default -1) resume from the last checkpoint
+                found in exp_dir.
+
+        Returns:
+            stats: The stats structure (optionally loaded from checkpoint).
+        """
+        # Init the stats struct
+        visdom_env_charts = (
+            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
+        )
+        stats = Stats(
+            # log_vars should be a list, but OmegaConf might load them as ListConfig
+            list(log_vars),
+            visdom_env=visdom_env_charts,
+            verbose=False,
+            visdom_server=self.visdom_server,
+            visdom_port=self.visdom_port,
+        )
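
The list(log_vars) conversion above guards against OmegaConf containers leaking into Stats. A standalone sketch of why (assumes omegaconf is installed):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"log_vars": ["loss_rgb", "loss_mask"]})
    log_vars = cfg.log_vars

    # OmegaConf hands back a ListConfig, which is not a plain list ...
    assert not isinstance(log_vars, list)

    # ... but converting produces one, so Stats never sees an OmegaConf type.
    assert list(log_vars) == ["loss_rgb", "loss_mask"]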
+
+        model_path = None
+        if resume:
+            if resume_epoch > 0:
+                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+                if not os.path.isfile(model_path):
+                    raise FileNotFoundError(
+                        f"Cannot find stats from epoch {resume_epoch}."
+                    )
+            else:
+                model_path = model_io.find_last_checkpoint(exp_dir)
+
+        if model_path is not None:
+            stats_path = model_io.get_stats_path(model_path)
+            stats_load = model_io.load_stats(stats_path)
+
+            # Determine if stats should be reset
+            if resume:
+                if stats_load is None:
+                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
+                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
+                    logger.info(f"Estimated resume epoch = {last_epoch}")
+
+                    # Reset the stats struct
+                    for _ in range(last_epoch + 1):
+                        stats.new_epoch()
+                    assert last_epoch == stats.epoch
+                else:
+                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
+                    stats = stats_load
+
+                # Update stats properties in case they were reset on load
+                stats.visdom_env = visdom_env_charts
+                stats.visdom_server = self.visdom_server
+                stats.visdom_port = self.visdom_port
+                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
+                stats.synchronize_logged_vars(log_vars)
+            else:
+                logger.info("Clearing stats")
+
+        return stats
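
The replay loop above works because a fresh stats object starts at epoch -1, so last_epoch + 1 calls to new_epoch() land exactly on the checkpoint's epoch. A minimal stub (not the real Stats class) demonstrating the invariant:

    class StatsStub:
        """Mimics only the epoch bookkeeping of implicitron's Stats."""

        def __init__(self) -> None:
            self.epoch = -1  # a fresh stats object starts before epoch 0

        def new_epoch(self) -> None:
            self.epoch += 1

    stats = StatsStub()
    last_epoch = 7  # e.g. as parsed from the checkpoint path
    for _ in range(last_epoch + 1):
        stats.new_epoch()
    assert stats.epoch == last_epoch  # same invariant as the assert above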
+
 
     def _training_or_validation_epoch(
         self,
         epoch: int,