Mirror of https://github.com/facebookresearch/pytorch3d.git (synced 2025-08-02 03:42:50 +08:00)
Move load_stats to TrainingLoop
Summary: Stats are logically connected to the training loop, not to the model, so load_stats moves to the training loop. resume_epoch is also removed from OptimizerFactory in favor of a single place, ModelFactory, which removes the need for config consistency checks.

Reviewed By: kjchalup

Differential Revision: D38313475

fbshipit-source-id: a1d188a63e28459df381ff98ad8acdcdb14887b7
Parent: 760305e044
Commit: c3f8dad55c
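In essence, after this change Experiment.run asks the training loop (not the model factory) for the Stats object and forwards the model factory's resume settings to both the stats loading and the optimizer factory. Below is a minimal sketch of that wiring, abridged from the Experiment hunks later in this diff; the exact model-factory call arguments and the optimizer/scheduler variable names are assumptions, and the surrounding setup, error handling, and accelerator plumbing are elided.

# Abridged sketch of Experiment.run after this change (not the full method).
model = self.model_factory(accelerator=accelerator, exp_dir=self.exp_dir)

# Stats now come from the training loop; the resume flags still live on
# the model factory, which is the single source of truth for resuming.
stats = self.training_loop.load_stats(
    log_vars=model.log_vars,
    exp_dir=self.exp_dir,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)
start_epoch = stats.epoch + 1

# The optimizer factory receives the same flags per call instead of
# carrying its own resume / resume_epoch fields.
optimizer, scheduler = self.optimizer_factory(
    exp_dir=self.exp_dir,
    last_epoch=start_epoch,
    model=model,
    resume=self.model_factory.resume,
    resume_epoch=self.model_factory.resume_epoch,
)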
@@ -3,6 +3,7 @@ defaults:
 - _self_
 exp_dir: ./data/exps/base/
 training_loop_ImplicitronTrainingLoop_args:
+  visdom_port: 8097
   visualize_interval: 0
   max_epochs: 1000
 data_source_ImplicitronDataSource_args:
@@ -22,7 +23,6 @@ data_source_ImplicitronDataSource_args:
     mask_depths: false
     mask_images: false
 model_factory_ImplicitronModelFactory_args:
-  visdom_port: 8097
   model_GenericModel_args:
     loss_weights:
       loss_mask_bce: 1.0
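With this layout the Visdom options are configured on the training loop rather than on the model factory, while resuming stays on the model factory. A hedged sketch of adjusting these keys programmatically, mirroring the Hydra compose pattern used in the resume test further below; the config directory path and config name here are illustrative, not taken from the diff.

from hydra import compose, initialize_config_dir

# Path and config name are illustrative; see the resume test later in this
# diff for the same compose pattern used in-tree.
with initialize_config_dir(config_dir="/path/to/implicitron_trainer/configs"):
    cfg = compose(config_name="repro_base", overrides=[])

# Visdom settings now live under the training loop...
cfg.training_loop_ImplicitronTrainingLoop_args.visdom_port = 8098
# ...while resuming is controlled solely by the model factory.
cfg.model_factory_ImplicitronModelFactory_args.resume = True
cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = -1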
@@ -147,9 +147,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
         run_auto_creation(self)
 
     def run(self) -> None:
-        # Make sure the config settings are self-consistent.
-        self._check_config_consistent()
-
         # Initialize the accelerator if desired.
         if no_accelerate:
             accelerator = None
@@ -176,9 +173,11 @@ class Experiment(Configurable):  # pyre-ignore: 13
             exp_dir=self.exp_dir,
         )
 
-        stats = self.model_factory.load_stats(
-            exp_dir=self.exp_dir,
+        stats = self.training_loop.load_stats(
             log_vars=model.log_vars,
+            exp_dir=self.exp_dir,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,  # pyre-ignore [16]
         )
         start_epoch = stats.epoch + 1
 
@@ -190,6 +189,8 @@ class Experiment(Configurable):  # pyre-ignore: 13
             exp_dir=self.exp_dir,
             last_epoch=start_epoch,
             model=model,
+            resume=self.model_factory.resume,
+            resume_epoch=self.model_factory.resume_epoch,
         )
 
         # Wrap all modules in the distributed library
@@ -224,26 +225,6 @@ class Experiment(Configurable):  # pyre-ignore: 13
             seed=self.seed,
         )
 
-    def _check_config_consistent(self) -> None:
-        if hasattr(self.optimizer_factory, "resume") and hasattr(
-            self.model_factory, "resume"
-        ):
-            assert (
-                # pyre-ignore [16]
-                not self.optimizer_factory.resume
-                # pyre-ignore [16]
-                or self.model_factory.resume
-            ), "Cannot resume the optimizer without resuming the model."
-        if hasattr(self.optimizer_factory, "resume_epoch") and hasattr(
-            self.model_factory, "resume_epoch"
-        ):
-            assert (
-                # pyre-ignore [16]
-                self.optimizer_factory.resume_epoch
-                # pyre-ignore [16]
-                == self.model_factory.resume_epoch
-            ), "Optimizer and model must resume from the same epoch."
-
 
 def _setup_envvars_for_cluster() -> bool:
     """
@@ -6,13 +6,13 @@
 
 import logging
 import os
-from typing import List, Optional
+from typing import Optional
 
 import torch.optim
 
 from accelerate import Accelerator
 from pytorch3d.implicitron.models.base_model import ImplicitronModelBase
-from pytorch3d.implicitron.tools import model_io, vis_utils
+from pytorch3d.implicitron.tools import model_io
 from pytorch3d.implicitron.tools.config import (
     registry,
     ReplaceableBase,
@@ -24,6 +24,9 @@ logger = logging.getLogger(__name__)
 
 
 class ModelFactoryBase(ReplaceableBase):
+
+    resume: bool = True  # resume from the last checkpoint
+
     def __call__(self, **kwargs) -> ImplicitronModelBase:
         """
         Initialize the model (possibly from a previously saved state).
@@ -45,27 +48,22 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
     A factory class that initializes an implicit rendering model.
 
     Members:
-        force_load: If True, throw a FileNotFoundError if `resume` is True but
-            a model checkpoint cannot be found.
         model: An ImplicitronModelBase object.
         resume: If True, attempt to load the last checkpoint from `exp_dir`
             passed to __call__. Failure to do so will return a model with ini-
-            tial weights unless `force_load` is True.
+            tial weights unless `force_resume` is True.
         resume_epoch: If `resume` is True: Resume a model at this epoch, or if
            `resume_epoch` <= 0, then resume from the latest checkpoint.
-        visdom_env: The name of the Visdom environment to use for plotting.
-        visdom_port: The Visdom port.
-        visdom_server: Address of the Visdom server.
+        force_resume: If True, throw a FileNotFoundError if `resume` is True but
+            a model checkpoint cannot be found.
     """
 
-    force_load: bool = False
     model: ImplicitronModelBase
     model_class_type: str = "GenericModel"
-    resume: bool = False
+    resume: bool = True
     resume_epoch: int = -1
-    visdom_env: str = ""
-    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
-    visdom_server: str = "http://127.0.0.1"
+    force_resume: bool = False
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -87,24 +85,27 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
             model: The model with optionally loaded weights from checkpoint
 
         Raise:
-            FileNotFoundError if `force_load` is True but checkpoint not found.
+            FileNotFoundError if `force_resume` is True but checkpoint not found.
         """
         # Determine the network outputs that should be logged
         if hasattr(self.model, "log_vars"):
-            log_vars = list(self.model.log_vars)  # pyre-ignore [6]
+            log_vars = list(self.model.log_vars)
         else:
            log_vars = ["objective"]
 
-        # Retrieve the last checkpoint
         if self.resume_epoch > 0:
+            # Resume from a certain epoch
             model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
+            if not os.path.isfile(model_path):
+                raise ValueError(f"Cannot find model from epoch {self.resume_epoch}.")
         else:
+            # Retrieve the last checkpoint
             model_path = model_io.find_last_checkpoint(exp_dir)
 
         if model_path is not None:
-            logger.info("found previous model %s" % model_path)
-            if self.force_load or self.resume:
-                logger.info(" -> resuming")
+            logger.info(f"Found previous model {model_path}")
+            if self.force_resume or self.resume:
+                logger.info("Resuming.")
 
                 map_location = None
                 if accelerator is not None and not accelerator.is_local_main_process:
@@ -120,81 +121,13 @@ class ImplicitronModelFactory(ModelFactoryBase):  # pyre-ignore [13]
                 except RuntimeError as e:
                     logger.error(e)
                     logger.info(
-                        "Cant load state dict in strict mode! -> trying non-strict"
+                        "Cannot load state dict in strict mode! -> trying non-strict"
                     )
                     self.model.load_state_dict(model_state_dict, strict=False)
-                self.model.log_vars = log_vars  # pyre-ignore [16]
+                self.model.log_vars = log_vars
             else:
-                logger.info(" -> but not resuming -> starting from scratch")
-        elif self.force_load:
+                logger.info("Not resuming -> starting from scratch.")
+        elif self.force_resume:
             raise FileNotFoundError(f"Cannot find a checkpoint in {exp_dir}!")
 
         return self.model
 
-    def load_stats(
-        self,
-        log_vars: List[str],
-        exp_dir: str,
-        clear_stats: bool = False,
-        **kwargs,
-    ) -> Stats:
-        """
-        Load Stats that correspond to the model's log_vars.
-
-        Args:
-            log_vars: A list of variable names to log. Should be a subset of the
-                `preds` returned by the forward function of the corresponding
-                ImplicitronModelBase instance.
-            exp_dir: Root experiment directory.
-            clear_stats: If True, do not load stats from the checkpoint speci-
-                fied by self.resume and self.resume_epoch; instead, create a
-                fresh stats object.
-
-        stats: The stats structure (optionally loaded from checkpoint)
-        """
-        # Init the stats struct
-        visdom_env_charts = (
-            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
-        )
-        stats = Stats(
-            # log_vars should be a list, but OmegaConf might load them as ListConfig
-            list(log_vars),
-            plot_file=os.path.join(exp_dir, "train_stats.pdf"),
-            visdom_env=visdom_env_charts,
-            verbose=False,
-            visdom_server=self.visdom_server,
-            visdom_port=self.visdom_port,
-        )
-        if self.resume_epoch > 0:
-            model_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
-        else:
-            model_path = model_io.find_last_checkpoint(exp_dir)
-
-        if model_path is not None:
-            stats_path = model_io.get_stats_path(model_path)
-            stats_load = model_io.load_stats(stats_path)
-
-            # Determine if stats should be reset
-            if not clear_stats:
-                if stats_load is None:
-                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
-                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
-                    logger.info(f"Estimated resume epoch = {last_epoch}")
-
-                    # Reset the stats struct
-                    for _ in range(last_epoch + 1):
-                        stats.new_epoch()
-                    assert last_epoch == stats.epoch
-                else:
-                    stats = stats_load
-
-                # Update stats properties incase it was reset on load
-                stats.visdom_env = visdom_env_charts
-                stats.visdom_server = self.visdom_server
-                stats.visdom_port = self.visdom_port
-                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
-                stats.synchronize_logged_vars(log_vars)
-            else:
-                logger.info(" -> clearing stats")
-
-        return stats
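The factory's checkpoint selection above now fails early when a specific resume_epoch cannot be found. The same selection logic, distilled into a standalone toy helper for illustration only; the file naming scheme here is hypothetical and is not pytorch3d's model_io layout.

import glob
import os
from typing import Optional


def resolve_checkpoint(exp_dir: str, resume_epoch: int = -1) -> Optional[str]:
    """Toy illustration of the selection logic used by the factory above:
    a positive resume_epoch must match an existing file, otherwise the
    latest checkpoint (if any) is used. File naming is hypothetical."""
    if resume_epoch > 0:
        path = os.path.join(exp_dir, f"model_epoch_{resume_epoch:08d}.pth")
        if not os.path.isfile(path):
            raise ValueError(f"Cannot find model from epoch {resume_epoch}.")
        return path
    candidates = sorted(glob.glob(os.path.join(exp_dir, "model_epoch_*.pth")))
    return candidates[-1] if candidates else None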
@@ -60,11 +60,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         multistep_lr_milestones: With MultiStepLR policy only: list of
             increasing epoch indices at which the learning rate is modified.
         momentum: Momentum factor for SGD optimizer.
-        resume: If True, attempt to load the last checkpoint from `exp_dir`
-            passed to __call__. Failure to do so will return a newly initialized
-            optimizer.
-        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
-            `resume_epoch` <= 0, then resume from the latest checkpoint.
         weight_decay: The optimizer weight_decay (L2 penalty on model weights).
     """
 
@@ -76,8 +71,6 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
     lr_policy: str = "MultiStepLR"
     momentum: float = 0.9
     multistep_lr_milestones: tuple = ()
-    resume: bool = False
-    resume_epoch: int = -1
     weight_decay: float = 0.0
 
     def __post_init__(self):
@@ -89,6 +82,8 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         model: ImplicitronModelBase,
         accelerator: Optional[Accelerator] = None,
         exp_dir: Optional[str] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
         **kwargs,
     ) -> Tuple[torch.optim.Optimizer, Any]:
         """
@@ -100,7 +95,10 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
             model: The model with optionally loaded weights.
             accelerator: An optional Accelerator instance.
             exp_dir: Root experiment directory.
+            resume: If True, attempt to load optimizer checkpoint from exp_dir.
+                Failure to do so will return a newly initialized optimizer.
+            resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+                `resume_epoch` <= 0, then resume from the latest checkpoint.
         Returns:
             An optimizer module (optionally loaded from a checkpoint) and
             a learning rate scheduler module (should be a subclass of torch.optim's
@@ -131,13 +129,18 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 p_groups, lr=self.lr, betas=self.betas, weight_decay=self.weight_decay
             )
         else:
-            raise ValueError("no such solver type %s" % self.breed)
-        logger.info(" -> solver type = %s" % self.breed)
+            raise ValueError(f"No such solver type {self.breed}")
+        logger.info(f"Solver type = {self.breed}")
 
         # Load state from checkpoint
-        optimizer_state = self._get_optimizer_state(exp_dir, accelerator)
+        optimizer_state = self._get_optimizer_state(
+            exp_dir,
+            accelerator,
+            resume_epoch=resume_epoch,
+            resume=resume,
+        )
         if optimizer_state is not None:
-            logger.info(" -> setting loaded optimizer state")
+            logger.info("Setting loaded optimizer state.")
             optimizer.load_state_dict(optimizer_state)
 
         # Initialize the learning rate scheduler
@@ -169,20 +172,31 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
         self,
         exp_dir: Optional[str],
         accelerator: Optional[Accelerator] = None,
+        resume: bool = True,
+        resume_epoch: int = -1,
     ) -> Optional[Dict[str, Any]]:
         """
         Load an optimizer state from a checkpoint.
 
+        resume: If True, attempt to load the last checkpoint from `exp_dir`
+            passed to __call__. Failure to do so will return a newly initialized
+            optimizer.
+        resume_epoch: If `resume` is True: Resume optimizer at this epoch. If
+            `resume_epoch` <= 0, then resume from the latest checkpoint.
         """
-        if exp_dir is None or not self.resume:
+        if exp_dir is None or not resume:
             return None
-        if self.resume_epoch > 0:
-            save_path = model_io.get_checkpoint(exp_dir, self.resume_epoch)
+        if resume_epoch > 0:
+            save_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+            if not os.path.isfile(save_path):
+                raise FileNotFoundError(
+                    f"Cannot find optimizer from epoch {resume_epoch}."
+                )
         else:
             save_path = model_io.find_last_checkpoint(exp_dir)
         optimizer_state = None
         if save_path is not None:
-            logger.info(f"Found previous optimizer state {save_path}.")
-            logger.info(" -> resuming")
+            logger.info(f"Found previous optimizer state {save_path} -> resuming.")
             opt_path = model_io.get_optimizer_path(save_path)
 
             if os.path.isfile(opt_path):
@@ -193,5 +207,5 @@ class ImplicitronOptimizerFactory(OptimizerFactoryBase):
                 }
                 optimizer_state = torch.load(opt_path, map_location)
             else:
-                optimizer_state = None
+                raise FileNotFoundError(f"Optimizer state {opt_path} does not exist.")
         return optimizer_state
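With this change the optimizer factory no longer stores resume state: resume and resume_epoch arrive as call arguments, forwarded from the model factory by Experiment.run. A usage sketch only; optimizer_factory, model, and start_epoch are assumed to come from the experiment's config machinery and the loaded stats, and the illustrative exp_dir value is not from the diff.

# Usage sketch: the factory is stateless with respect to resuming.
optimizer, scheduler = optimizer_factory(
    model=model,
    accelerator=None,
    exp_dir="./data/exps/base/",
    last_epoch=start_epoch,
    resume=True,       # forwarded from the model factory's resume flag
    resume_epoch=-1,   # <= 0 -> latest optimizer checkpoint
)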
@@ -5,8 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 import logging
+import os
 import time
-from typing import Any, Optional
+from typing import Any, List, Optional
 
 import torch
 from accelerate import Accelerator
@@ -41,6 +42,16 @@ class TrainingLoopBase(ReplaceableBase):
     ) -> None:
         raise NotImplementedError()
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        raise NotImplementedError()
+
 
 @registry.register
 class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
@@ -64,6 +75,9 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
             logged.
         visualize_interval: The batch interval at which the visualizations
             should be plotted
+        visdom_env: The name of the Visdom environment to use for plotting.
+        visdom_port: The Visdom port.
+        visdom_server: Address of the Visdom server.
     """
 
     # Parameters of the outer training loop.
@@ -77,10 +91,15 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
     test_when_finished: bool = False
     validation_interval: int = 1
 
-    # Parameters of a single training-validation step.
+    # Gradient clipping.
     clip_grad: float = 0.0
 
+    # Visualization/logging parameters.
     metric_print_interval: int = 5
     visualize_interval: int = 1000
+    visdom_env: str = ""
+    visdom_port: int = int(os.environ.get("VISDOM_PORT", 8097))
+    visdom_server: str = "http://127.0.0.1"
 
     def __post_init__(self):
         run_auto_creation(self)
@@ -202,6 +221,81 @@ class ImplicitronTrainingLoop(TrainingLoopBase):  # pyre-ignore [13]
                 "Cannot evaluate and dump results to json, no test data provided."
             )
 
+    def load_stats(
+        self,
+        log_vars: List[str],
+        exp_dir: str,
+        resume: bool = True,
+        resume_epoch: int = -1,
+        **kwargs,
+    ) -> Stats:
+        """
+        Load Stats that correspond to the model's log_vars and resume_epoch.
+
+        Args:
+            log_vars: A list of variable names to log. Should be a subset of the
+                `preds` returned by the forward function of the corresponding
+                ImplicitronModelBase instance.
+            exp_dir: Root experiment directory.
+            resume: If False, do not load stats from the checkpoint speci-
+                fied by resume and resume_epoch; instead, create a fresh stats object.
+
+        stats: The stats structure (optionally loaded from checkpoint)
+        """
+        # Init the stats struct
+        visdom_env_charts = (
+            vis_utils.get_visdom_env(self.visdom_env, exp_dir) + "_charts"
+        )
+        stats = Stats(
+            # log_vars should be a list, but OmegaConf might load them as ListConfig
+            list(log_vars),
+            visdom_env=visdom_env_charts,
+            verbose=False,
+            visdom_server=self.visdom_server,
+            visdom_port=self.visdom_port,
+        )
+
+        model_path = None
+        if resume:
+            if resume_epoch > 0:
+                model_path = model_io.get_checkpoint(exp_dir, resume_epoch)
+                if not os.path.isfile(model_path):
+                    raise FileNotFoundError(
+                        f"Cannot find stats from epoch {resume_epoch}."
+                    )
+            else:
+                model_path = model_io.find_last_checkpoint(exp_dir)
+
+        if model_path is not None:
+            stats_path = model_io.get_stats_path(model_path)
+            stats_load = model_io.load_stats(stats_path)
+
+            # Determine if stats should be reset
+            if resume:
+                if stats_load is None:
+                    logger.warning("\n\n\n\nCORRUPT STATS -> clearing stats\n\n\n\n")
+                    last_epoch = model_io.parse_epoch_from_model_path(model_path)
+                    logger.info(f"Estimated resume epoch = {last_epoch}")
+
+                    # Reset the stats struct
+                    for _ in range(last_epoch + 1):
+                        stats.new_epoch()
+                    assert last_epoch == stats.epoch
+                else:
+                    logger.info(f"Found previous stats in {stats_path} -> resuming.")
+                    stats = stats_load
+
+                # Update stats properties incase it was reset on load
+                stats.visdom_env = visdom_env_charts
+                stats.visdom_server = self.visdom_server
+                stats.visdom_port = self.visdom_port
+                stats.plot_file = os.path.join(exp_dir, "train_stats.pdf")
+                stats.synchronize_logged_vars(log_vars)
+            else:
+                logger.info("Clearing stats")
+
+        return stats
+
     def _training_or_validation_epoch(
         self,
         epoch: int,
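Given an already-constructed ImplicitronTrainingLoop (normally built by the experiment's config machinery, so construction is elided here), resuming the stats now looks roughly like the sketch below; the log_vars list and exp_dir are illustrative values, not taken from the diff.

# Usage sketch only: training_loop is an ImplicitronTrainingLoop created by
# the experiment's config machinery; values below are illustrative.
stats = training_loop.load_stats(
    log_vars=["objective", "loss_rgb_mse", "sec/it"],
    exp_dir="./data/exps/base/",
    resume=True,       # False -> start from a fresh Stats object
    resume_epoch=-1,   # <= 0 -> use the latest checkpoint's stats
)
start_epoch = stats.epoch + 1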
@@ -124,14 +124,32 @@ data_source_ImplicitronDataSource_args:
     dataset_length_val: 0
     dataset_length_test: 0
 model_factory_ImplicitronModelFactory_args:
-  force_load: false
+  resume: true
   model_class_type: GenericModel
-  resume: false
   resume_epoch: -1
-  visdom_env: ''
-  visdom_port: 8097
-  visdom_server: http://127.0.0.1
+  force_resume: false
   model_GenericModel_args:
+    log_vars:
+    - loss_rgb_psnr_fg
+    - loss_rgb_psnr
+    - loss_rgb_mse
+    - loss_rgb_huber
+    - loss_depth_abs
+    - loss_depth_abs_fg
+    - loss_mask_neg_iou
+    - loss_mask_bce
+    - loss_mask_beta_prior
+    - loss_eikonal
+    - loss_density_tv
+    - loss_depth_neg_penalty
+    - loss_autodecoder_norm
+    - loss_prev_stage_rgb_mse
+    - loss_prev_stage_rgb_psnr_fg
+    - loss_prev_stage_rgb_psnr
+    - loss_prev_stage_mask_bce
+    - objective
+    - epoch
+    - sec/it
     mask_images: true
     mask_depths: true
     render_image_width: 400
@@ -162,27 +180,6 @@ model_factory_ImplicitronModelFactory_args:
       loss_prev_stage_rgb_mse: 1.0
       loss_mask_bce: 0.0
       loss_prev_stage_mask_bce: 0.0
-    log_vars:
-    - loss_rgb_psnr_fg
-    - loss_rgb_psnr
-    - loss_rgb_mse
-    - loss_rgb_huber
-    - loss_depth_abs
-    - loss_depth_abs_fg
-    - loss_mask_neg_iou
-    - loss_mask_bce
-    - loss_mask_beta_prior
-    - loss_eikonal
-    - loss_density_tv
-    - loss_depth_neg_penalty
-    - loss_autodecoder_norm
-    - loss_prev_stage_rgb_mse
-    - loss_prev_stage_rgb_psnr_fg
-    - loss_prev_stage_rgb_psnr
-    - loss_prev_stage_mask_bce
-    - objective
-    - epoch
-    - sec/it
     global_encoder_HarmonicTimeEncoder_args:
       n_harmonic_functions: 10
       append_input: true
@@ -422,8 +419,6 @@ optimizer_factory_ImplicitronOptimizerFactory_args:
   lr_policy: MultiStepLR
   momentum: 0.9
   multistep_lr_milestones: []
-  resume: false
-  resume_epoch: -1
   weight_decay: 0.0
 training_loop_ImplicitronTrainingLoop_args:
   eval_only: false
@@ -437,6 +432,9 @@ training_loop_ImplicitronTrainingLoop_args:
   clip_grad: 0.0
   metric_print_interval: 5
   visualize_interval: 1000
+  visdom_env: ''
+  visdom_port: 8097
+  visdom_server: http://127.0.0.1
   evaluator_ImplicitronEvaluator_args:
     camera_difficulty_bin_breaks:
     - 0.97
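A hedged sketch of inspecting the relocated defaults from a dumped config such as the one above; the file name experiment.yaml is an assumption, while the key paths match the diff.

from omegaconf import OmegaConf

# "experiment.yaml" stands in for the dumped default config shown above.
cfg = OmegaConf.load("experiment.yaml")

mf = cfg.model_factory_ImplicitronModelFactory_args
print(mf.resume, mf.resume_epoch, mf.force_resume)        # True -1 False
print(mf.model_GenericModel_args.log_vars[:3])            # first few logged keys

tl = cfg.training_loop_ImplicitronTrainingLoop_args
print(tl.visdom_server, tl.visdom_port, tl.visdom_env)    # Visdom now lives here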
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import tempfile
 import unittest
 from pathlib import Path
 
@@ -172,3 +173,48 @@ class TestNerfRepro(unittest.TestCase):
                 experiment_runner = experiment.Experiment(**cfg)
                 experiment.dump_cfg(cfg)
                 experiment_runner.run()
+
+    @unittest.skip("This test checks resuming of the NeRF training.")
+    def test_nerf_blender_resume(self):
+        # Train one train batch of NeRF, then resume for one more batch.
+        # Set env vars BLENDER_DATASET_ROOT and BLENDER_SINGLESEQ_CLASS first!
+        if not interactive_testing_requested():
+            return
+        with initialize_config_dir(config_dir=str(IMPLICITRON_CONFIGS_DIR)):
+            with tempfile.TemporaryDirectory() as exp_dir:
+                cfg = compose(config_name="repro_singleseq_nerf_blender", overrides=[])
+                cfg.exp_dir = exp_dir
+
+                # set dataset len to 1
+
+                # fmt: off
+                (
+                    cfg
+                    .data_source_ImplicitronDataSource_args
+                    .data_loader_map_provider_SequenceDataLoaderMapProvider_args
+                    .dataset_length_train
+                ) = 1
+                # fmt: on
+
+                # run for one epoch
+                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 1
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment.dump_cfg(cfg)
+                experiment_runner.run()
+
+                # update num epochs + 2, let the optimizer resume
+                cfg.training_loop_ImplicitronTrainingLoop_args.max_epochs = 3
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()
+
+                # start from scratch
+                cfg.model_factory_ImplicitronModelFactory_args.resume = False
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()
+
+                # force resume from epoch 1
+                cfg.model_factory_ImplicitronModelFactory_args.resume = True
+                cfg.model_factory_ImplicitronModelFactory_args.force_resume = True
+                cfg.model_factory_ImplicitronModelFactory_args.resume_epoch = 1
+                experiment_runner = experiment.Experiment(**cfg)
+                experiment_runner.run()
@@ -344,7 +344,7 @@ def export_scenes(
 
     # Load the previously trained model
     experiment = Experiment(config)
-    model = experiment.model_factory(force_load=True, load_model_only=True)
+    model = experiment.model_factory(force_resume=True)
    model.cuda()
    model.eval()
 
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional
 
 import torch
@@ -45,6 +45,10 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
             optimization.
     """
 
+    # The keys from `preds` (output of ImplicitronModelBase.forward) to be logged in
+    # the training loop.
+    log_vars: List[str] = field(default_factory=lambda: ["objective"])
+
    def __init__(self) -> None:
        super().__init__()
 
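Because log_vars is now a regular configurable field on ImplicitronModelBase, a model subclass can widen the set of logged outputs by redefining the field's default, mirroring the pattern above. A hedged sketch only: the subclass name and the extra key are illustrative, and the usual registry/config plumbing is omitted.

from dataclasses import field
from typing import List

from pytorch3d.implicitron.models.base_model import ImplicitronModelBase


class MyLoggingModel(ImplicitronModelBase):  # illustrative subclass only
    # Log a (hypothetical) extra prediction key on top of the default objective.
    log_vars: List[str] = field(
        default_factory=lambda: ["objective", "loss_rgb_mse"]
    )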
@@ -1,3 +1,24 @@
+log_vars:
+- loss_rgb_psnr_fg
+- loss_rgb_psnr
+- loss_rgb_mse
+- loss_rgb_huber
+- loss_depth_abs
+- loss_depth_abs_fg
+- loss_mask_neg_iou
+- loss_mask_bce
+- loss_mask_beta_prior
+- loss_eikonal
+- loss_density_tv
+- loss_depth_neg_penalty
+- loss_autodecoder_norm
+- loss_prev_stage_rgb_mse
+- loss_prev_stage_rgb_psnr_fg
+- loss_prev_stage_rgb_psnr
+- loss_prev_stage_mask_bce
+- objective
+- epoch
+- sec/it
 mask_images: true
 mask_depths: true
 render_image_width: 400
@@ -28,27 +49,6 @@ loss_weights:
   loss_prev_stage_rgb_mse: 1.0
   loss_mask_bce: 0.0
   loss_prev_stage_mask_bce: 0.0
-log_vars:
-- loss_rgb_psnr_fg
-- loss_rgb_psnr
-- loss_rgb_mse
-- loss_rgb_huber
-- loss_depth_abs
-- loss_depth_abs_fg
-- loss_mask_neg_iou
-- loss_mask_bce
-- loss_mask_beta_prior
-- loss_eikonal
-- loss_density_tv
-- loss_depth_neg_penalty
-- loss_autodecoder_norm
-- loss_prev_stage_rgb_mse
-- loss_prev_stage_rgb_psnr_fg
-- loss_prev_stage_rgb_psnr
-- loss_prev_stage_mask_bce
-- objective
-- epoch
-- sec/it
 global_encoder_SequenceAutodecoder_args:
   autodecoder_args:
     encoding_dim: 0
@@ -90,5 +90,5 @@ class TestGenericModel(unittest.TestCase):
         remove_unused_components(instance_args)
         yaml = OmegaConf.to_yaml(instance_args, sort_keys=False)
         if DEBUG:
-            (DATA_DIR / "overrides.yaml_").write_text(yaml)
+            (DATA_DIR / "overrides_.yaml").write_text(yaml)
         self.assertEqual(yaml, (DATA_DIR / "overrides.yaml").read_text())