Mirror of https://github.com/facebookresearch/pytorch3d.git, synced 2025-12-23 15:50:39 +08:00
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so load_stats is moved to the training loop. resume_epoch is also removed from OptimizerFactory so that it lives in a single place, ModelFactory, which removes the need for config consistency checks between the two factories.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
committed by Facebook GitHub Bot
parent 760305e044
commit c3f8dad55c
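In effect, the code that drives training now asks the training loop for the stats object rather than the model-side loading code. A rough sketch of the new call pattern (a hedged illustration only: the driver code and the names training_loop, model and exp_dir are placeholders, not part of this diff):

    # Hypothetical driver code around the new API.
    # `training_loop` is an ImplicitronTrainingLoop, `model` an ImplicitronModelBase.
    stats = training_loop.load_stats(
        log_vars=list(model.log_vars),  # assumption: the model exposes its log_vars
        exp_dir=exp_dir,                # placeholder experiment directory
        resume=True,                    # pick up from the last checkpoint if present
    )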
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
 import os
 import time
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
     ) -> None:
         raise NotImplementedError()
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
             logged.
         visualize_interval: The batch interval at which the visualizations
             should be plotted
+        visdom_env: The name of the Visdom environment to use for plotting.
+        visdom_port: The Visdom port.
+        visdom_server: Address of the Visdom server.
     """
 
     # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
     test_when_finished: bool = False
     validation_interval: int = 1
 
-    # Parameters of a single training-validation step.
+    # Gradient clipping.
     clip_grad: float = 0.0
+
+    # Visualization/logging parameters.
     metric_print_interval: int = 5
     visualize_interval: int = 1000
+    visdom_env: str = ""
+    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
+    visdom_server: str = "http://127.0.0.1"
 
     def __post_init__(self):
         run_auto_creation(self)
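The fields above are plain dataclass-style defaults, so they can be overridden through the usual OmegaConf-based configuration. A minimal, hypothetical sketch (the field names come from this hunk; the enclosing Implicitron experiment config machinery is omitted):

    from omegaconf import DictConfig, OmegaConf

    # Partial config for ImplicitronTrainingLoop covering only the
    # visualization/logging fields introduced in this diff.
    loop_overrides: DictConfig = OmegaConf.create(
        {
            "metric_print_interval": 5,
            "visualize_interval": 200,           # plot every 200 batches
            "visdom_env": "my_experiment",       # overrides the empty default
            "visdom_port": 8097,
            "visdom_server": "http://127.0.0.1",
        }
    )
    print(OmegaConf.to_yaml(loop_overrides))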
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase): # pyre-ignore [13]
                 "Cannot evaluate and dump results to json, no test data provided."
             )
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        """
+        Load Stats that correspond to the model's log_vars and resume_epoch.
+
+        Args:
+            log_vars: A list of variable names to log. Should be a subset of the
+                `preds` returned by the forward function of the corresponding
+                ImplicitronModelBase instance.
+            exp_dir: Root experiment directory.
+            resume: If False, do not load stats from the checkpoint specified
+                by resume and resume_epoch; instead, create a fresh stats object.
+
+            stats: The stats structure (optionally loaded from checkpoint)
+        """
+        # Init the stats struct
+        visdom_env_charts = (
+            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
+        )
+        stats = Stats(
+            # log_vars should be a list, but OmegaConf might load them as ListConfig
+            list(log_vars),
+            visdom_env=visdom_env_charts,
+            verbose=False,
+            visdom_server=self.visdom_server,
+            visdom_port=self.visdom_port,
+        )
+
+        model_path = None
+        if resume:
+            if resume_epoch > 0:
+                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+                if not os.path.isfile(model_path):
+                    raise FileNotFoundError(
+                        f"Cannot find stats from epoch {resume_epoch}."
+                    )
+            else:
+                model_path = model_io.find_last_checkpoint(exp_dir)
+
+        if model_path is not None:
+            stats_path = model_io.get_stats_path(model_path)
+            stats_load = model_io.load_stats(stats_path)
+
+            # Determine if stats should be reset
+            if resume:
+                if stats_load is None:
+                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
+                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
+                    logger.info(f"Estimated resume epoch = {last_epoch}")
+
+                    # Reset the stats struct
+                    for _ in range(last_epoch + 1):
+                        stats.new_epoch()
+                    assert last_epoch == stats.epoch
+                else:
+                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
+                    stats = stats_load
+
+                # Update stats properties in case it was reset on load
+                stats.visdom_env = visdom_env_charts
+                stats.visdom_server = self.visdom_server
+                stats.visdom_port = self.visdom_port
+                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
+                stats.synchronize_logged_vars(log_vars)
+            else:
+                logger.info("Clearing stats")
+
+        return stats
+
     def _training_or_validation_epoch(
         self,
         epoch: int,
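For reference, a hedged usage sketch of the method added above, showing the three resume behaviours encoded in its body (exp_dir and the log_vars list are placeholder values, and training_loop stands for an already-constructed ImplicitronTrainingLoop):

    log_vars = ["objective", "sec/it"]        # placeholder variable names
    exp_dir = "/tmp/implicitron_experiment"   # placeholder experiment directory

    # Resume from the last checkpoint found in exp_dir, if any.
    stats = training_loop.load_stats(log_vars, exp_dir, resume=True)

    # Resume from a specific epoch; raises FileNotFoundError if that
    # checkpoint does not exist.
    stats_at_10 = training_loop.load_stats(log_vars, exp_dir, resume=True, resume_epoch=10)

    # Ignore any checkpoints and start with a fresh Stats object.
    fresh_stats = training_loop.load_stats(log_vars, exp_dir, resume=False)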