mirror of
https://github.com/facebookresearch/pytorch3d.git
synced 2025-12-23 15:50:39 +08:00
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so `load_stats` moves to the training loop. This also removes `resume_epoch` from OptimizerFactory in favor of a single place: ModelFactory. Doing so removes the need for config consistency checks, etc.
Reviewed By: kjchalup
Differential Revision: D38313475
fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
This commit is contained in:
committed by
Facebook GitHub Bot
parent
760305e044
commit
c3f8dad55c
@@ -6,13 +6,13 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch.optim
|
||||
|
||||
from accelerate import Accelerator
|
||||
from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
|
||||
from pytorch3d.implicitron.tools import model_io, vis_utils
|
||||
from pytorch3d.implicitron.tools import model_io
|
||||
from pytorch3d.implicitron.tools.config import (
|
||||
registry,
|
||||
ReplaceableBase,
|
||||
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ModelFactoryBase(ReplaceableBase):
|
||||
|
||||
resume: bool = True # resume from the last checkpoint
|
||||
|
||||
def __call__(self, **kwargs) -> ImplicitronModelBase:
|
||||
"""
|
||||
Initialize the model (possibly from a previously saved state).
|
||||
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
||||
A factory class that initializes an implicit rendering model.
|
||||
|
||||
Members:
|
||||
force_load: If True, throw a FileNotFoundError if `resume` is True but
|
||||
a model checkpoint cannot be found.
|
||||
model: An ImplicitronModelBase object.
|
||||
resume: If True, attempt to load the last checkpoint from `exp_dir`
|
||||
passed to __call__. Failure to do so will return a model with ini-
|
||||
tial weights unless `force_load` is True.
|
||||
tial weights unless `force_resume` is True.
|
||||
resume_epoch: If `resume` is True: Resume a model at this epoch, or if
|
||||
`resume_epoch` <= 0, then resume from the latest checkpoint.
|
||||
visdom_env: The name of the Visdom environment to use for plotting.
|
||||
visdom_port: The Visdom port.
|
||||
visdom_server: Address of the Visdom server.
|
||||
force_resume: If True, throw a FileNotFoundError if `resume` is True but
|
||||
a model checkpoint cannot be found.
|
||||
|
||||
"""
|
||||
|
||||
force_load: bool = False
|
||||
model: ImplicitronModelBase
|
||||
model_class_type: str = "GenericModel"
|
||||
resume: bool = False
|
||||
resume: bool = True
|
||||
resume_epoch: int = -1
|
||||
visdom_env: str = ""
|
||||
visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
|
||||
visdom_server: str = "http://127.0.0.1"
|
||||
force_resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
run_auto_creation(self)
|
||||
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
||||
model: The model with optionally loaded weights from checkpoint
|
||||
|
||||
Raise:
|
||||
FileNotFoundError if `force_load` is True but checkpoint not found.
|
||||
FileNotFoundError if `force_resume` is True but checkpoint not found.
|
||||
"""
|
||||
# Determine the network outputs that should be logged
|
||||
if hasattr(self.model, "log_vars"):
|
||||
log_vars = list(self.model.log_vars) # pyre-ignore [6]
|
||||
log_vars = list(self.model.log_vars)
|
||||
else:
|
||||
log_vars = ["objective"]
|
||||
|
||||
# Retrieve the last checkpoint
|
||||
if self.resume_epoch > 0:
|
||||
# Resume from a certain epoch
|
||||
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
|
||||
if not os.path.isfile(model_path):
|
||||
raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
|
||||
else:
|
||||
# Retrieve the last checkpoint
|
||||
model_path = model_io.find_last_checkpoint(exp_dir)
|
||||
|
||||
if model_path is not None:
|
||||
logger.info("found previous model %s" % model_path)
|
||||
if self.force_load or self.resume:
|
||||
logger.info(" -> resuming")
|
||||
logger.info(f"Found previous model {model_path}")
|
||||
if self.force_resume or self.resume:
|
||||
logger.info("Resuming.")
|
||||
|
||||
map_location = None
|
||||
if accelerator is not None and not accelerator.is_local_main_process:
|
||||
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase): # pyre-ignore [13]
|
||||
except RuntimeError as e:
|
||||
logger.error(e)
|
||||
logger.info(
|
||||
"Cant load state dict in strict mode! -> trying non-strict"
|
||||
"Cannot load state dict in strict mode! -> trying non-strict"
|
||||
)
|
||||
self.model.load_state_dict(model_state_dict, strict=False)
|
||||
self.model.log_vars = log_vars # pyre-ignore [16]
|
||||
self.model.log_vars = log_vars
|
||||
else:
|
||||
logger.info(" -> but not resuming -> starting from scratch")
|
||||
elif self.force_load:
|
||||
logger.info("Not resuming -> starting from scratch.")
|
||||
elif self.force_resume:
|
||||
raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
|
||||
|
||||
return self.model
|
||||
|
||||
def load_stats(
|
||||
self,
|
||||
log_vars: List[str],
|
||||
exp_dir: str,
|
||||
clear_stats: bool = False,
|
||||
**kwargs,
|
||||
) -> Stats:
|
||||
"""
|
||||
Load Stats that correspond to the model's log_vars.
|
||||
|
||||
Args:
|
||||
log_vars: A list of variable names to log. Should be a subset of the
|
||||
`preds` returned by the forward function of the corresponding
|
||||
ImplicitronModelBase instance.
|
||||
exp_dir: Root experiment directory.
|
||||
clear_stats: If True, do not load stats from the checkpoint speci-
|
||||
fied by self.resume and self.resume_epoch; instead, create a
|
||||
fresh stats object.
|
||||
|
||||
stats: The stats structure (optionally loaded from checkpoint)
|
||||
"""
|
||||
# Init the stats struct
|
||||
visdom_env_charts = (
|
||||
vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
|
||||
)
|
||||
stats = Stats(
|
||||
# log_vars should be a list, but OmegaConf might load them as ListConfig
|
||||
list(log_vars),
|
||||
plot_file=os.path.join(exp_dir, "train_stats.pdf"),
|
||||
visdom_env=visdom_env_charts,
|
||||
verbose=False,
|
||||
visdom_server=self.visdom_server,
|
||||
visdom_port=self.visdom_port,
|
||||
)
|
||||
if self.resume_epoch > 0:
|
||||
model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
|
||||
else:
|
||||
model_path = model_io.find_last_checkpoint(exp_dir)
|
||||
|
||||
if model_path is not None:
|
||||
stats_path = model_io.get_stats_path(model_path)
|
||||
stats_load = model_io.load_stats(stats_path)
|
||||
|
||||
# Determine if stats should be reset
|
||||
if not clear_stats:
|
||||
if stats_load is None:
|
||||
logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
|
||||
last_epoch = model_io.parse_epoch_from_model_path(model_path)
|
||||
logger.info(f"Estimated resume epoch = {last_epoch}")
|
||||
|
||||
# Reset the stats struct
|
||||
for _ in range(last_epoch + 1):
|
||||
stats.new_epoch()
|
||||
assert last_epoch == stats.epoch
|
||||
else:
|
||||
stats = stats_load
|
||||
|
||||
# Update stats properties in case they were reset on load
|
||||
stats.visdom_env = visdom_env_charts
|
||||
stats.visdom_server = self.visdom_server
|
||||
stats.visdom_port = self.visdom_port
|
||||
stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
|
||||
stats.synchronize_logged_vars(log_vars)
|
||||
else:
|
||||
logger.info(" -> clearing stats")
|
||||
|
||||
return stats
|
||||
|
||||
Reference in New Issue
Block a user